In [1]:
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, Reshape, Embedding, Concatenate, Conv2DTranspose, Conv2D, Flatten, Dropout, ReLU, LeakyReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
def build_cgan_generator(latent_dim, n_cats, in_shape=(7, 7), n_channels=1):
    """Build the conditional-GAN generator.

    The latent vector is projected to a ``in_shape`` grid with 128 feature
    maps. The class label is embedded and projected to a single ``in_shape``
    map. The two are concatenated along the channel axis and upsampled 4x by
    two stride-2 transposed convolutions.

    Args:
        latent_dim: Length of the random noise vector.
        n_cats: Number of label categories (embedding vocabulary size).
        in_shape: (height, width) of the base grid. Generated images have
            spatial size (4 * height, 4 * width).
        n_channels: Number of channels of the generated image (default 1).

    Returns:
        A Keras ``Model`` with inputs ``[latent, label]``. Its output has shape
        (4 * in_shape[0], 4 * in_shape[1], n_channels), with tanh activation.
    """
    base_h, base_w = in_shape[0], in_shape[1]

    # Label input: class id -> 50-d embedding -> one base_h x base_w feature map.
    in_label = Input(shape=(1,), name='Generator-Label-Input-Layer')
    lbls = Embedding(n_cats, 50, name='Generator-Label-Embedding-Layer')(in_label)
    lbls = Dense(base_h * base_w, name='Generator-Label-Dense-Layer')(lbls)
    lbls = Reshape((base_h, base_w, 1), name='Generator-Label-Reshape-Layer')(lbls)

    # Latent input. `shape` is given as a tuple; a bare int is not accepted by
    # every Keras version.
    in_latent = Input(shape=(latent_dim,), name='Generator-Latent-Input-Layer')

    # Project the noise to base_h x base_w x 128. The node count is derived from
    # in_shape; a hard-coded 7*7*128 breaks the Reshape for any other in_shape.
    g = Dense(base_h * base_w * 128, name='Generator-Foundation-Layer')(in_latent)
    g = ReLU(name='Generator-Foundation-Layer-Activation-1')(g)
    g = Reshape((base_h, base_w, 128), name='Generator-Foundation-Layer-Reshape-1')(g)

    # Append the label map as an extra channel.
    concat = Concatenate(name='Generator-Combine-Layer')([g, lbls])

    # Hidden Layer 1: upsample 2x.
    g = Conv2DTranspose(filters=128, kernel_size=(4, 4), strides=(2, 2), padding='same', name='Generator-Hidden-Layer-1')(concat)
    g = ReLU(name='Generator-Hidden-Layer-Activation-1')(g)

    # Hidden Layer 2: upsample 2x again (4x in total).
    g = Conv2DTranspose(filters=128, kernel_size=(4, 4), strides=(2, 2), padding='same', name='Generator-Hidden-Layer-2')(g)
    g = ReLU(name='Generator-Hidden-Layer-Activation-2')(g)

    # Output layer: 'same' padding keeps the spatial size; tanh keeps pixels in [-1, 1].
    output_layer = Conv2D(filters=n_channels, kernel_size=(7, 7), activation='tanh', padding='same', name='Generator-Output-Layer')(g)

    return Model([in_latent, in_label], output_layer, name='Generator')

def build_cgan_discriminator(img_shape, n_cats):
    """Build the conditional-GAN discriminator.

    The class label is embedded and projected to a single-channel map with the
    image's spatial size. That map is concatenated to the image as an extra
    channel. The result goes through two stride-2 convolutions and a sigmoid
    output unit.

    Args:
        img_shape: (height, width, channels) of the input images.
        n_cats: Number of label categories (embedding vocabulary size).

    Returns:
        A Keras ``Model`` with inputs ``[image, label]`` and one sigmoid output
        per sample.
    """
    height, width = img_shape[0], img_shape[1]

    # Image input.
    in_img = Input(shape=img_shape, name='Discriminator-Image-Input-Layer')

    # Label input: class id -> 50-d embedding -> one height x width map.
    in_label = Input(shape=(1,), name='Discriminator-Label-Input-Layer')
    lbls = Embedding(n_cats, 50, name='Discriminator-Label-Embedding-Layer')(in_label)
    lbls = Dense(height * width, name='Discriminator-Label-Dense-Layer')(lbls)
    # The Dense layer yields height * width values, i.e. exactly one channel.
    # Reshaping to img_shape itself would fail whenever img_shape has >1 channel.
    lbls = Reshape((height, width, 1), name='Discriminator-Label-Reshape-Layer')(lbls)

    # Concatenate the image and label map along the channel axis.
    concat = Concatenate(name='Discriminator-Combine-Layer')([in_img, lbls])

    # Hidden Layer 1: downsample 2x.
    x = Conv2D(32, kernel_size=4, strides=2, padding='same', name='Discriminator-Hidden-Layer-1')(concat)
    x = LeakyReLU(alpha=0.2, name='Discriminator-Hidden-Layer-Activation-1')(x)

    # Hidden Layer 2: downsample 2x again.
    x = Conv2D(64, kernel_size=4, strides=2, padding='same', name='Discriminator-Hidden-Layer-2')(x)
    x = LeakyReLU(alpha=0.2, name='Discriminator-Hidden-Layer-Activation-2')(x)

    # Flatten and output layer.
    x = Flatten(name='Discriminator-Flatten-Layer')(x)
    output_layer = Dense(1, activation='sigmoid', name='Discriminator-Output-Layer')(x)

    return Model([in_img, in_label], output_layer, name='Discriminator')

## Define the data loader, models, and optimizers (losses are computed in the training loop)

In [None]:
def create_dataloader(image_dir, img_shape, batch_size):
    """Create a Keras directory iterator of lightly augmented images.

    Images are read from the subfolders of ``image_dir`` and resized to
    ``img_shape[:2]``. They are augmented with small zooms and shifts and
    scaled to [-1, 1]. That is the range of the generator's tanh output, so
    real and fake images share the same value range.

    Args:
        image_dir: Root directory containing one subfolder per class.
        img_shape: Target (height, width[, channels]). A missing or 1-valued
            channel dimension loads grayscale; otherwise RGB.
        batch_size: Number of images per batch.

    Returns:
        A ``DirectoryIterator`` yielding image batches only (``class_mode=None``).
    """
    def _to_tanh_range(x):
        # Raw pixel values [0, 255] -> [-1, 1]. Rescaling to [0, 1] instead would
        # let the discriminator separate real from fake images by value range alone.
        return x / 127.5 - 1.0

    transform = ImageDataGenerator(
        preprocessing_function=_to_tanh_range,
        zoom_range=0.1,
        width_shift_range=0.1,
        height_shift_range=0.1,
    )

    is_grayscale = len(img_shape) < 3 or img_shape[2] == 1

    return transform.flow_from_directory(
        image_dir,
        target_size=img_shape[:2],
        color_mode='grayscale' if is_grayscale else 'rgb',
        batch_size=batch_size,
        class_mode=None,
    )

# Parameters
image_dir = 'data'
# Must equal the generator's output size. Its default 7x7 base grid is
# upsampled twice by stride-2 transposed convolutions, giving 28x28 images.
# The discriminator and the dataloader both take their size from this value.
img_shape = (28, 28, 1)
latent_dim = 100
batch_size = 128
num_epochs = 100
n_cats = 10
learning_rate = 0.0002
adam_beta_1 = 0.5

# Build models and optimizers
generator = build_cgan_generator(latent_dim, n_cats)
discriminator = build_cgan_discriminator(img_shape, n_cats)

# Fail fast here instead of inside the first training step.
assert tuple(generator.output_shape[1:]) == tuple(img_shape), (
    f'Generator output {generator.output_shape[1:]} does not match img_shape {img_shape}'
)

# `lr` is a deprecated alias that newer Keras versions reject; use `learning_rate`.
generator_optimizer = Adam(learning_rate=learning_rate, beta_1=adam_beta_1)
discriminator_optimizer = Adam(learning_rate=learning_rate, beta_1=adam_beta_1)

## Create the dataloader

In [None]:
# Build the training image iterator from the parameters defined above
dataloader = create_dataloader(image_dir=image_dir, img_shape=img_shape, batch_size=batch_size)

## GAN training

In [None]:
# Training loop
bce = tf.keras.losses.BinaryCrossentropy()  # mean-reduced -> scalar loss per batch

# `dataloader` was built with class_mode=None, so it yields images without labels.
# Build an iterator over the same directory, with the same preprocessing and
# augmentation, but with class_mode='sparse' so each batch carries its own labels.
# Label = index of the image's subfolder in sorted order. For folders named
# '0'..'9' this equals the folder name; other naming schemes need checking.
labelled_loader = dataloader.image_data_generator.flow_from_directory(
    dataloader.directory,
    target_size=dataloader.target_size,
    color_mode=dataloader.color_mode,
    batch_size=dataloader.batch_size,
    class_mode='sparse',
)
if labelled_loader.samples == 0:
    raise ValueError(f'No training images found under {dataloader.directory!r}')
if labelled_loader.num_classes > n_cats:
    # Label ids >= n_cats would index past the end of the embedding tables.
    raise ValueError(
        f'Found {labelled_loader.num_classes} class folders but n_cats={n_cats}'
    )

# Keras directory iterators never stop on their own, so bound each epoch
# explicitly by the number of batches in one pass over the data.
steps_per_epoch = len(labelled_loader)

for epoch in range(num_epochs):
    for _ in range(steps_per_epoch):
        real_imgs, labels = next(labelled_loader)
        n_batch = real_imgs.shape[0]  # the last batch of an epoch may be smaller
        labels = tf.reshape(tf.cast(labels, tf.int32), (-1, 1))

        # Update the discriminator: real images -> 1, generated images -> 0.
        fake_imgs = generator([tf.random.normal((n_batch, latent_dim)), labels])

        with tf.GradientTape() as tape:
            real_preds = discriminator([real_imgs, labels], training=True)
            fake_preds = discriminator([fake_imgs, labels], training=True)

            real_loss = bce(tf.ones_like(real_preds), real_preds)
            fake_loss = bce(tf.zeros_like(fake_preds), fake_preds)
            total_loss = (real_loss + fake_loss) * 0.5

        disc_grads = tape.gradient(total_loss, discriminator.trainable_weights)
        discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_weights))

        # Update the generator: try to make the discriminator output 1 for
        # images generated from randomly sampled labels.
        sampled_labels = tf.random.uniform((n_batch, 1), minval=0, maxval=n_cats, dtype=tf.int32)

        with tf.GradientTape() as tape:
            fake_imgs = generator([tf.random.normal((n_batch, latent_dim)), sampled_labels], training=True)
            fake_preds = discriminator([fake_imgs, sampled_labels], training=True)

            gen_loss = bce(tf.ones_like(fake_preds), fake_preds)

        gen_grads = tape.gradient(gen_loss, generator.trainable_weights)
        generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_weights))

    # Report training progress (losses of the last batch of the epoch).
    print(f'Epoch: {epoch + 1}, Generator Loss: {gen_loss.numpy().mean()}, Discriminator Loss: {total_loss.numpy().mean()}')
