In [1]:
import os
from zipfile import ZipFile
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import gdown

In [2]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# Applied to every visible GPU (TF requires the setting to be consistent across
# devices) and guarded so the notebook still runs on a CPU-only machine.
physical_devices = tf.config.list_physical_devices("GPU")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
if not physical_devices:
    print("No GPU found; training will run on CPU.")

In [4]:
# Unlabeled CelebA images resized to 64x64, shuffled, in batches of 64.
# Pixels are divided by 255 so they lie in [0, 1], matching the generator's
# sigmoid output. Parallel map + prefetch let tf.data decode upcoming batches
# while the current training step runs; the elements produced are unchanged.
dataset = (
    keras.preprocessing.image_dataset_from_directory(
        directory="celeb_dataset/img_align_celeba",
        label_mode=None,
        image_size=(64, 64),
        batch_size=64,
        shuffle=True,
    )
    .map(lambda x: x / 255.0, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

Found 202599 files belonging to 1 classes.


In [5]:
# Folder that receives the periodic generator samples saved during training.
folder = 'generated_images'


def clear_directory(path):
    """Ensure ``path`` exists and delete the files and symlinks directly inside it.

    Creating the directory first avoids a FileNotFoundError on a fresh checkout
    (the training loop also writes into it). Subdirectories are left untouched.
    Individual deletion failures are reported and skipped so one locked file
    does not abort the notebook.
    """
    os.makedirs(path, exist_ok=True)
    for filename in os.listdir(path):
        file_path = os.path.join(path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
        except OSError as e:
            print('Failed to delete %s. Reason: %s' % (file_path, e))


# Start each run with an empty sample folder so old and new samples don't mix.
clear_directory(folder)

In [6]:
# Discriminator: 64x64 RGB image -> probability (sigmoid) that the image is real.
# Three stride-2 convolutions downsample 64 -> 32 -> 16 -> 8 spatially before
# the dropout-regularized dense head.
discriminator = keras.Sequential(
    [
        keras.Input(shape=(64, 64, 3)),
        layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Flatten(),
        layers.Dropout(0.2),
        layers.Dense(1, activation="sigmoid"),
    ]
)

# summary() prints the table itself and returns None, so it is not wrapped in
# print() (which would emit a stray "None" line).
discriminator.summary()

# Size of the random noise vector fed to the generator.
latent_dim = 128

# Generator: latent vector -> 8x8x128 feature map, upsampled 8 -> 16 -> 32 -> 64
# by stride-2 transposed convolutions, then projected to 3 channels. The sigmoid
# keeps outputs in [0, 1], the same range as the normalized dataset images.
generator = keras.Sequential(
    [
        layers.Input(shape=(latent_dim,)),
        layers.Dense(8 * 8 * 128),
        layers.Reshape((8, 8, 128)),
        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
    ]
)

generator.summary()

# Separate Adam optimizers so each network keeps its own moment estimates.
opt_gen = keras.optimizers.Adam(1e-4)
opt_disc = keras.optimizers.Adam(1e-4)

# The discriminator already ends in a sigmoid, so BCE is used on probabilities
# (the default from_logits=False).
loss_fn = keras.losses.BinaryCrossentropy()


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 32, 32, 64)        3136      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 32, 32, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 16, 16, 128)       131200    
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 16, 16, 128)       0         
                                                                 
 conv2d_2 (Conv2D)           (None, 8, 8, 128)         262272    
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 8, 8, 128)         0         
                                                                 
 flatten (Flatten)           (None, 8192)              0

In [7]:
# Checkpoint everything needed to resume training: both models, both optimizers
# and the epoch counter. CheckpointManager chooses its own file names inside
# checkpoint_dir; checkpoint_path only serves to derive that directory.
checkpoint_path = "checkpoint/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Number of fully completed epochs; stored in the checkpoint so training can resume.
ep_cnt = tf.Variable(initial_value=0, trainable=False, dtype=tf.int64)

checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    opt_gen=opt_gen,
    opt_disc=opt_disc,
    ep_cnt=ep_cnt,
)

# Keep only the 5 most recent checkpoints on disk.
cp_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

latest_checkpoint = cp_manager.latest_checkpoint
if latest_checkpoint is None:
    # Fresh run: nothing to restore (restore(None) would only raise an assertion).
    print(f"No checkpoint found in '{checkpoint_dir}'; starting from scratch.")
else:
    # Let a mismatch between the checkpoint and the current models/optimizers
    # raise: silently continuing from a partially restored state would corrupt
    # the resumed training.
    checkpoint.restore(latest_checkpoint).assert_existing_objects_matched()
    print(f"Restored {latest_checkpoint} ({int(ep_cnt.numpy())} epochs completed).")


In [10]:
# Epochs already completed (restored from the checkpoint, or 0 on a fresh start).
print(f"Epochs completed: {int(ep_cnt.numpy())}")

<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=0>


In [None]:
# Adversarial training loop. Resumes after the epochs recorded in ep_cnt, saves
# one generated sample every `sample_every` batches, and checkpoints after each
# completed epoch.
num_epochs = 25
sample_every = 100
sample_dir = 'generated_images'
# img.save() below fails if the folder is missing, so make sure it exists.
os.makedirs(sample_dir, exist_ok=True)

# ep_cnt counts completed epochs; start right after them instead of looping
# over (and printing) every finished epoch.
start_epoch = int(ep_cnt.numpy())

for epoch in range(start_epoch, num_epochs):
    print(f"start epoch {epoch}")

    for idx, real in enumerate(tqdm(dataset)):
        # The last batch of an epoch may be smaller than the configured batch size.
        batch_size = real.shape[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))

        fake = generator(random_latent_vectors, training=True)

        if idx % sample_every == 0:
            img = keras.preprocessing.image.array_to_img(fake[0])
            img.save(f'{sample_dir}/generated_image_{epoch}_i{idx}.png')

        # Train the discriminator: real images -> 1, generated images -> 0.
        # training=True is required in a custom loop; without it the model runs
        # in inference mode and the Dropout layer is never active.
        with tf.GradientTape() as disc_tape:
            loss_disc_real = loss_fn(tf.ones((batch_size, 1)), discriminator(real, training=True))
            loss_disc_fake = loss_fn(tf.zeros((batch_size, 1)), discriminator(fake, training=True))
            loss_disc = (loss_disc_fake + loss_disc_real) / 2

        grads = disc_tape.gradient(loss_disc, discriminator.trainable_weights)
        opt_disc.apply_gradients(zip(grads, discriminator.trainable_weights))

        # Train the generator: push the just-updated discriminator to output 1
        # on fakes. The generator forward pass is repeated under gen_tape so
        # gradients can flow back to the generator weights.
        with tf.GradientTape() as gen_tape:
            fake = generator(random_latent_vectors, training=True)
            output = discriminator(fake, training=True)
            # Targets must be shaped (batch_size, 1) like `output`; the second
            # positional argument of tf.ones is the dtype, not a dimension.
            loss_gen = loss_fn(tf.ones((batch_size, 1)), output)

        grads = gen_tape.gradient(loss_gen, generator.trainable_weights)
        opt_gen.apply_gradients(zip(grads, generator.trainable_weights))

    # Record the finished epoch before saving so a resumed run skips it.
    ep_cnt.assign_add(1)
    cp_manager.save(checkpoint_number=epoch)

start epoch 0


100%|██████████| 3166/3166 [58:21<00:00,  1.11s/it]


start epoch 1


100%|██████████| 3166/3166 [57:49<00:00,  1.10s/it]


start epoch 2


 95%|█████████▌| 3012/3166 [54:49<02:48,  1.09s/it]

In [None]:
# Export the trained networks, generator first, each to saved_model/<name>.
for model_name, trained_model in (("generator", generator), ("discriminator", discriminator)):
    trained_model.save(f"saved_model/{model_name}")