# GANs con redes densas

In [None]:
# Importamos las librerías
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, LeakyReLU
from tensorflow.keras import optimizers, backend
from tensorflow.keras.datasets import mnist

In [None]:
# Start from a clean Keras session: resets Keras global state (e.g. the
# auto-generated model/layer name counters) left over from earlier runs.
backend.clear_session()

# Seed Python, NumPy and TensorFlow in one call so that weight initialisation
# (and any later random sampling) is reproducible under Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

## Red generadora

In [None]:
# Generator: maps a 20-dimensional input vector to a 784-value output through
# two LeakyReLU-activated hidden layers (100 and 300 units).
# NOTE(review): 784 = 28 * 28, presumably flattened MNIST digits (mnist is
# imported above); the tanh output lies in [-1, 1], so real images are
# presumably rescaled to that range before training — confirm in the data cell.
generador = Sequential()
generador.add(Dense(100, input_shape = (20,)))
generador.add(LeakyReLU(alpha = 0.3))
generador.add(Dense(300))
# Without a non-linearity here, Dense(300) -> Dense(784) would compose into a
# single affine map and the 300-unit hidden layer would add no expressiveness.
generador.add(LeakyReLU(alpha = 0.3))
generador.add(Dense(784, activation = 'tanh'))
generador.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 100)               2100      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 100)               0         
                                                                 
 dense_1 (Dense)             (None, 300)               30300     
                                                                 
 dense_2 (Dense)             (None, 784)               235984    
                                                                 
Total params: 268384 (1.02 MB)
Trainable params: 268384 (1.02 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


## Red discriminativa

In [None]:
# Discriminator: takes a 784-value input and produces a single sigmoid score,
# via two LeakyReLU-activated hidden layers (300 and 100 units).
discriminador = Sequential([
    Dense(300, input_shape = (784,)),
    LeakyReLU(alpha = 0.3),
    Dense(100),
    LeakyReLU(alpha = 0.3),
    Dense(1, activation = 'sigmoid'),
])
discriminador.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_3 (Dense)             (None, 300)               235500    
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 300)               0         
                                                                 
 dense_4 (Dense)             (None, 100)               30100     
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 100)               0         
                                                                 
 dense_5 (Dense)             (None, 1)                 101       
                                                                 
Total params: 265701 (1.01 MB)
Trainable params: 265701 (1.01 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


## Red GAN

In [None]:
# The discriminator is also trained on its own, so it gets its own compile
# step with a dedicated Adam optimizer and an accuracy metric.
adam_1 = optimizers.Adam(learning_rate = 0.001)
discriminador.compile(
    optimizer = adam_1,
    loss = 'binary_crossentropy',
    metrics = ['accuracy'],
)

In [None]:
# The generator is trained through the discriminator's output, so both
# networks are chained into a single stacked model.

# Freeze the discriminator before stacking it: inside `gan` only the generator's
# weights should be updated (gan.summary() reports the discriminator's
# parameters as non-trainable).
# NOTE(review): this relies on tf.keras (Keras 2) reading `trainable` at compile
# time — the discriminator was compiled in the previous cell while still
# trainable, so it presumably keeps learning when trained on its own. Keras 3
# handles this differently; confirm if the TF/Keras version changes.
discriminador.trainable = False

# GAN = generator followed by the frozen discriminator.
gan = Sequential([generador, discriminador])

In [None]:
# Show the stacked model: the generator's parameters are listed as trainable and
# the frozen discriminator's as non-trainable.
gan.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential (Sequential)     (None, 784)               268384    
                                                                 
 sequential_1 (Sequential)   (None, 1)                 265701    
                                                                 
Total params: 534085 (2.04 MB)
Trainable params: 268384 (1.02 MB)
Non-trainable params: 265701 (1.01 MB)
_________________________________________________________________


In [None]:
# Compile the stacked GAN (generator + frozen discriminator). It uses its own
# Adam instance (adam_2), separate from the discriminator's adam_1.
adam_2 = optimizers.Adam(learning_rate = 0.001)
gan.compile(
    optimizer = adam_2,
    loss = 'binary_crossentropy',
)