## Librerías

In [1]:
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import fashion_mnist
import tensorflow as tf
import numpy as np
import os, io
import datetime
import matplotlib.pyplot as plt
import time

## Parámetros de imagen y CGAN

In [2]:
# Image geometry: 28x28 single-channel images.
img_rows, img_cols, img_channels = 28, 28, 1
img_shape = (img_rows, img_cols, img_channels)

latent_space = 100  # length of the generator's noise vector z
num_classes = 10    # number of class labels the CGAN is conditioned on
NUM_SAMPLES = 9     # images drawn for each preview grid

## Discriminator

In [3]:
def create_discriminator():
    """Build the conditional discriminator.

    The model takes an image of shape ``img_shape`` and a class label of
    shape (1,). The label is embedded and projected to a single
    ``img_rows x img_cols`` channel, then concatenated with the image. Two
    stride-2 convolutions, a flatten and a dropout follow.

    Returns:
        A Keras ``Model`` mapping ``[image, label]`` to one unbounded score
        (a logit) per sample. No sigmoid is applied here because the loss
        (``cross_entropy``) is built with ``from_logits=True`` and applies
        the sigmoid itself.
    """
    # image input
    img_input = layers.Input(shape=img_shape)
    # label input -> embedding -> one extra image-sized channel.
    # Reshape to a single channel explicitly so this also works when
    # img_channels != 1 (reshaping to img_shape would not).
    label_input = layers.Input(shape=(1,))
    hidden_layer = layers.Embedding(num_classes, 50)(label_input)
    hidden_layer = layers.Dense(img_rows * img_cols)(hidden_layer)
    hidden_layer = layers.Reshape((img_rows, img_cols, 1))(hidden_layer)
    # stack the label channel onto the image channels
    merge = layers.Concatenate()([img_input, hidden_layer])
    # downsample
    hidden_layer = layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same')(merge)
    hidden_layer = layers.LeakyReLU(alpha=0.2)(hidden_layer)
    # downsample
    hidden_layer = layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same')(hidden_layer)
    hidden_layer = layers.LeakyReLU(alpha=0.2)(hidden_layer)
    # flatten features
    hidden_layer = layers.Flatten()(hidden_layer)
    hidden_layer = layers.Dropout(0.4)(hidden_layer)
    # output: raw logit (no activation), matching from_logits=True in the loss
    output_layer = layers.Dense(1)(hidden_layer)
    # define model
    model = Model([img_input, label_input], output_layer)
    return model

## Generator

In [4]:
def create_generator():
    """Build the conditional generator.

    Inputs are a noise vector of length ``latent_space`` and a class label of
    shape (1,). Both are turned into 7x7 feature maps (128 channels for the
    noise, 1 for the label) and concatenated. Two stride-2 transposed
    convolutions upsample 7x7 -> 14x14 -> 28x28. A final ``tanh`` convolution
    produces one output channel with values in [-1, 1].
    """
    # Noise branch: project z and reshape to a 7x7x128 map.
    z_in = layers.Input(shape=(latent_space,))
    z_features = layers.Dense(128 * 7 * 7)(z_in)
    z_features = layers.LeakyReLU(alpha=0.2)(z_features)
    z_map = layers.Reshape((7, 7, 128))(z_features)

    # Label branch: embed the class id and reshape to one 7x7 channel.
    class_in = layers.Input(shape=(1,))
    class_map = layers.Embedding(num_classes, 50)(class_in)
    class_map = layers.Dense(7 * 7)(class_map)
    class_map = layers.Reshape((7, 7, 1))(class_map)

    x = layers.Concatenate()([z_map, class_map])
    # Upsample twice: 7x7 -> 14x14 -> 28x28.
    for _ in range(2):
        x = layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(x)
        x = layers.LeakyReLU(alpha=0.2)(x)

    image_out = layers.Conv2D(1, (7, 7), activation='tanh', padding='same')(x)
    return Model([z_in, class_in], image_out)

## Define discriminator loss

In [5]:
def discriminator_loss(real_output, fake_output):
    """Discriminator loss with noisy (smoothed) targets.

    Each element of ``real_output`` gets a target drawn uniformly from
    [0.7, 1.2). Each element of ``fake_output`` gets a target drawn
    uniformly from [0.0, 0.3). Both are scored with ``cross_entropy``.

    Args:
        real_output: discriminator outputs for real images.
        fake_output: discriminator outputs for generated images.

    Returns:
        Tuple ``(real_loss, fake_loss)``. The caller sums them into the
        total discriminator loss.
    """
    # tf.shape (dynamic) rather than .shape (static): works even when the
    # batch dimension is unknown while tracing under tf.function. The noise
    # dtype follows the outputs so the labels and outputs always agree.
    # label real: 1 shifted into [0.7, 1.2)
    label_real = tf.ones_like(real_output)
    label_real = label_real - 0.3 + tf.random.uniform(tf.shape(label_real), dtype=label_real.dtype) * 0.5

    # label fake: 0 shifted into [0.0, 0.3)
    label_fake = tf.zeros_like(fake_output)
    label_fake = label_fake + tf.random.uniform(tf.shape(label_fake), dtype=label_fake.dtype) * 0.3

    real_loss = cross_entropy(label_real, real_output)
    fake_loss = cross_entropy(label_fake, fake_output)
    return real_loss, fake_loss

## Define generator loss

In [6]:
def generator_loss(fake_output):
    """Generator loss: push the discriminator to label fakes as real.

    The "real" target for each element of ``fake_output`` is drawn uniformly
    from [0.7, 1.2) (the same smoothing as the discriminator's real labels).

    Args:
        fake_output: discriminator outputs for generated images.

    Returns:
        Scalar ``cross_entropy`` loss.
    """
    label_real = tf.ones_like(fake_output)
    # tf.shape (dynamic) so this also works when the batch size is unknown
    # at trace time. The noise dtype matches the outputs.
    label_real = label_real - 0.3 + tf.random.uniform(tf.shape(label_real), dtype=label_real.dtype) * 0.5
    return cross_entropy(label_real, fake_output)

## Define the per-step training function

In [7]:
@tf.function
def train_step(images_batch, labels):
    """Run one adversarial update of both networks on a single batch.

    Generates one fake image per real image, using random noise and random
    class labels. Both batches are scored with the discriminator, each
    network's gradients are applied by its own optimizer, and the per-batch
    losses are accumulated into the global Mean metrics.

    Args:
        images_batch: batch of real images (load_samples scales them to [-1, 1]).
        labels: class ids matching ``images_batch``.
    """
    # Size the fake batch from the real one. The last batch of an epoch can be
    # smaller than BATCH_SIZE because the dataset is batched without
    # drop_remainder. Prefer the static size (known when tracing with concrete
    # inputs) and fall back to the dynamic one otherwise.
    batch_size = images_batch.shape[0]
    if batch_size is None:
        batch_size = tf.shape(images_batch)[0]
    noise = tf.random.normal([batch_size, latent_space])
    label_random = tf.random.uniform((batch_size,), minval=0, maxval=num_classes, dtype=tf.dtypes.int32)
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator([noise, label_random], training=True)

        real_output = discriminator([images_batch, labels], training=True)
        fake_output = discriminator([generated_images, label_random], training=True)

        gen_loss = generator_loss(fake_output)
        real_loss, fake_loss = discriminator_loss(real_output, fake_output)
        disc_loss = real_loss + fake_loss

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    # accumulate per-batch losses into the epoch-level running means
    loss_generator(gen_loss)
    loss_discriminator_real(real_loss)
    loss_discriminator_fake(fake_loss)

## Define the training loop

In [8]:
def train(dataset, epoch, save_every=2):
    """Train the CGAN for ``epoch`` epochs over ``dataset``.

    For each epoch:
    - runs ``train_step`` on every ``(images, labels)`` batch;
    - logs the mean losses to TensorBoard and prints them;
    - after the first epoch and every ``save_every`` epochs, writes a
      checkpoint and logs a grid of generated samples;
    - resets the loss metrics.

    All TensorBoard summaries use the 1-based epoch number as their step.

    Args:
        dataset: iterable of ``(image_batch, labels)`` pairs.
        epoch: number of epochs to run.
        save_every: checkpoint / sample-image interval in epochs. The default
            of 2 matches the interval the loop previously hard-coded.

    Raises:
        ValueError: if ``save_every`` is smaller than 1.
    """
    if save_every < 1:
        raise ValueError('save_every must be >= 1')
    template = ('Epoch {}, Loss generator: {}, Loss discriminator real: {}, '
                'Loss discriminator fake {}, Time: {}')
    for i in range(epoch):
        epoch_num = i + 1
        start = time.time()
        for image_batch, labels in dataset:
            train_step(image_batch, labels)

        with train_summary_writer.as_default():
            # write loss values
            tf.summary.scalar('loss generator', loss_generator.result(), step=epoch_num)
            tf.summary.scalar('loss discriminator real', loss_discriminator_real.result(), step=epoch_num)
            tf.summary.scalar('loss discriminator fake', loss_discriminator_fake.result(), step=epoch_num)

        print(template.format(epoch_num,
                              loss_generator.result(),
                              loss_discriminator_real.result(),
                              loss_discriminator_fake.result(),
                              time.time() - start))
        # Checkpoint and log sample images after the first epoch and every
        # `save_every` epochs.
        if epoch_num % save_every == 0 or i == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            images_example, labels_example = generated_samples_images()
            figure = image_grid(images_example, labels_example)
            with train_summary_writer.as_default():
                tf.summary.image("Epoch {}".format(epoch_num), plot_to_image(figure), step=epoch_num)

        loss_generator.reset_states()
        loss_discriminator_real.reset_states()
        loss_discriminator_fake.reset_states()

## Define functions to plot sample images

In [9]:
def plot_to_image(figure):
    """Convert a matplotlib figure to a PNG image tensor for ``tf.summary.image``.

    The figure is closed after rendering, even if rendering fails, so it is
    not also displayed inline in the notebook and cannot be used afterwards.

    Args:
        figure: the matplotlib ``Figure`` to render.

    Returns:
        A uint8 tensor of shape (1, height, width, 4): the RGBA image with a
        leading batch dimension.
    """
    buf = io.BytesIO()
    try:
        # Render *this* figure. plt.savefig would render pyplot's current
        # figure, which is not necessarily the one passed in.
        figure.savefig(buf, format='png')
    finally:
        plt.close(figure)
    # Convert PNG bytes to a TF image and add the batch dimension.
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    return tf.expand_dims(image, 0)

def image_grid(samples, label_samples, num_samples=None):
    """Return a grid of sample images, titled by class name, as a matplotlib figure.

    Args:
        samples: indexable collection of 2-D images, e.g. the first value
            returned by ``generated_samples_images``.
        label_samples: integer class ids matching ``samples``. They are looked
            up in the global ``class_names`` for the subplot titles.
        num_samples: how many images to plot. Defaults to ``NUM_SAMPLES``.
            The grid is the smallest near-square layout that fits them
            (3x3 for 9).

    Returns:
        The matplotlib ``Figure``.
    """
    if num_samples is None:
        num_samples = NUM_SAMPLES
    # Derive the layout from the sample count instead of assuming 3x3.
    n_cols = max(1, int(np.ceil(np.sqrt(num_samples))))
    n_rows = int(np.ceil(num_samples / n_cols))

    figure = plt.figure(figsize=(12, 12))
    for i in range(num_samples):
        ax = figure.add_subplot(n_rows, n_cols, i + 1)
        # int() accepts Python ints, numpy ints and scalar tensors alike.
        ax.set_title(class_names[int(label_samples[i])])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)
        ax.imshow(samples[i], cmap=plt.cm.binary)
    return figure

def generated_samples_images():
    """Draw ``NUM_SAMPLES`` images from the generator with random class labels.

    The generator's [-1, 1] output is mapped to [0, 255], cast to int32
    (truncating), and reshaped to (NUM_SAMPLES, img_rows, img_cols), dropping
    the channel axis.

    Returns:
        Tuple ``(images, labels)``: the int32 image tensor and the int32
        class ids used to condition the generator.
    """
    z = tf.random.normal([NUM_SAMPLES, latent_space])
    sample_labels = tf.random.uniform((NUM_SAMPLES,), minval=0, maxval=num_classes, dtype=tf.dtypes.int32)
    pixels = generator([z, sample_labels], training=False) * 127.5 + 127.5
    pixels = tf.reshape(tf.cast(pixels, tf.int32), [NUM_SAMPLES, img_rows, img_cols])
    return pixels, sample_labels

## Define function to load data

In [10]:
def load_samples():
    """Load the Fashion-MNIST training split, scaled for a tanh generator.

    Returns:
        x_train: float32 images with a trailing channel axis. Pixel values
            are mapped via (x - 127.5) / 127.5, i.e. [0, 255] -> [-1, 1].
        y_train: class ids exactly as returned by Keras.
        class_names: human-readable names indexed by class id
            (0 -> 'T-shirt/top', 1 -> 'Trouser', ...).
    """
    (images, labels), _ = fashion_mnist.load_data()
    images = images.astype('float32')[..., np.newaxis]
    images = (images - 127.5) / 127.5
    names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
             'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    return images, labels, names

## Define loss function

In [11]:
# Binary cross-entropy loss object shared by both loss functions.
# from_logits=True: Keras applies the sigmoid to the scores internally.
cross_entropy = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

## Define metrics

In [12]:
# Running means of the per-batch losses: updated in train_step, read and
# reset once per epoch in train.
def _loss_tracker(metric_name):
    return tf.keras.metrics.Mean(metric_name, dtype=tf.float32)

loss_generator = _loss_tracker('loss_generator')
loss_discriminator_real = _loss_tracker('loss_discriminator_real')
loss_discriminator_fake = _loss_tracker('loss_discriminator_fake')

## Define optimizers

In [13]:
# Generator: Adam (lr 2e-4, beta_1 0.5).
generator_optimizer = tf.keras.optimizers.Adam(beta_1=0.5, learning_rate=2e-4)
# Discriminator: SGD with momentum (lr 5e-4, momentum 0.9).
discriminator_optimizer = tf.keras.optimizers.SGD(momentum=0.9, learning_rate=5e-4)

## Create the CGAN

In [14]:
# Instantiate both networks (generator first, as before).
generator, discriminator = create_generator(), create_discriminator()

## Set up model checkpoints

In [15]:
# The checkpoint tracks both networks and both optimizer states. train()
# writes it under this prefix.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

## Log loss curves to TensorBoard

In [16]:
# One TensorBoard run directory per execution, named by its start time.
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
train_log_dir = f'logs/train/{current_time}'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)

## Load and batch the data

In [17]:
BUFFER_SIZE = 1000  # shuffle buffer (number of examples)
BATCH_SIZE = 128
x_train, y_train, class_names = load_samples()
# Shuffle, then batch. The final partial batch is kept (no drop_remainder).
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

## Train the CGAN

In [None]:
NUM_EPOCH = 50  # total passes over train_dataset
train(dataset=train_dataset, epoch=NUM_EPOCH)