<a href="https://colab.research.google.com/github/whukam/neuralNetworkProject/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
# Bring in the sequential api for the generator and discriminator
import tensorflow as tf
from tensorflow.keras.models import Sequential
# Bring in the layers for the neural network
from tensorflow.keras.layers import Conv2DTranspose, Dense, Flatten, Reshape, LeakyReLU, Dropout, UpSampling2D, BatchNormalization

In [12]:

def generator():
    model = Sequential()
    model.add(Dense(64*64*256, use_bias=False, input_shape=(100,)))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Reshape((64, 64, 256))) 

    model.add(Conv2DTranspose(128, (9, 9), strides=(1, 1), padding='same', use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Conv2DTranspose(64, (9, 9), strides=(2, 2), padding='same', use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Conv2DTranspose(1, (9, 9), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

    return model
     


In [13]:
def discriminator(): 
    model = Sequential()
    
    # First Conv Block
    model.add(Conv2DTranspose(32, 5, input_shape = (128,128,1)))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.4))
    
    # Second Conv Block
    model.add(Conv2DTranspose(64, 5))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.4))
    
    # Third Conv Block
    model.add(Conv2DTranspose(128, 5))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.4))
    
    # Fourth Conv Block
    model.add(Conv2DTranspose(256, 5))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.4))
    
    # Flatten then pass to dense layer
    model.add(Flatten())
    model.add(Dropout(0.4))
    model.add(Dense(1, activation='sigmoid'))
    
    return model 

In [14]:
generator_s2h = generator()
generator_h2s = generator()

discriminator_s2h = discriminator()
discriminator_h2s = discriminator()

In [15]:
generator_s2h.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_4 (Dense)             (None, 1048576)           104857600 
                                                                 
 batch_normalization_6 (Batc  (None, 1048576)          4194304   
 hNormalization)                                                 
                                                                 
 leaky_re_lu_14 (LeakyReLU)  (None, 1048576)           0         
                                                                 
 reshape_2 (Reshape)         (None, 64, 64, 256)       0         
                                                                 
 conv2d_transpose_14 (Conv2D  (None, 64, 64, 128)      2654208   
 Transpose)                                                      
                                                                 
 batch_normalization_7 (Batc  (None, 64, 64, 128)     

In [16]:
discriminator_s2h.summary()

Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_transpose_20 (Conv2D  (None, 132, 132, 32)     832       
 Transpose)                                                      
                                                                 
 leaky_re_lu_20 (LeakyReLU)  (None, 132, 132, 32)      0         
                                                                 
 dropout_10 (Dropout)        (None, 132, 132, 32)      0         
                                                                 
 conv2d_transpose_21 (Conv2D  (None, 136, 136, 64)     51264     
 Transpose)                                                      
                                                                 
 leaky_re_lu_21 (LeakyReLU)  (None, 136, 136, 64)      0         
                                                                 
 dropout_11 (Dropout)        (None, 136, 136, 64)     

In [17]:
# Adam is going to be the optimizer for both
from tensorflow.keras.optimizers import Adam
# Binary cross entropy is going to be the loss for both 
from tensorflow.keras.losses import BinaryCrossentropy

In [26]:
opt_obj = Adam(2e-4, beta_1=0.5) 
loss_obj = BinaryCrossentropy(from_logits=True)

In [19]:
# Importing the base model class to subclass our training step 
from tensorflow.keras.models import Model

In [37]:

class cycleGAN(Model): 
    def __init__(self, generator_s2h, discriminator_s2h, generator_h2s, discriminator_h2s, *args, **kwargs):
        # Pass through args and kwargs to base class 
        super().__init__(*args, **kwargs)
        
        self.LAMBDA = 15

        # Create attributes for gen and disc
        self.generator_g = generator_s2h
        self.generator_f = generator_h2s

        self.discriminator_x = discriminator_s2h
        self.discriminator_y = discriminator_h2s

    def discriminator_loss(real, generated):
      real_loss = loss_obj(tf.ones_like(real), real)
      
      generated_loss = loss_obj(tf.zeros_like(generated), generated)
      
      total_disc_loss = real_loss + generated_loss
      
      return total_disc_loss * 0.5
    
    def generator_loss(generated):
      return loss_obj(tf.ones_like(generated), generated)

    
    def calc_cycle_loss(real_image, cycled_image):
      loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
      return LAMBDA * loss1
    
    def identity_loss(real_image, same_image):
      loss = tf.reduce_mean(tf.abs(real_image - same_image))
      return LAMBDA * 0.5 * loss



    def compile(self, opt_obj, loss_obj, **kwargs): 
        # Compile with base class
        super().compile(*args, **kwargs)
        
        # Create attributes for losses and optimizers
        self.generator_g_optimizer = opt_obj
        self.discriminator_x_optimizer = opt_obj

        self.generator_f_optimizer = opt_obj
        self.discriminator_y_optimizer = opt_obj


    def train_step(real_x, real_y):
      # persistent is set to True because the tape is used more than
      # once to calculate the gradients.
      
      
      with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y
        # Generator F translates Y -> X.
        
        
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)
          
        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)
        
        # same_x and same_y are used for identity loss.
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)
        
        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)
        
        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)
        
        # calculate the loss
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)
        
        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
        
        
        # Total generator loss = adversarial loss + cycle loss
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)
        
        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
        
        
        # Calculate the gradients for generator and discriminator
        generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
        
        generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)
        
        
        discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
        
        discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)
        
        
        # Apply the gradients to the optimizer
        generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
        
        generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
        
        
        discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
        
        discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))


    




    

In [None]:
# Create instance of subclassed model
GAN_model = cycleGAN(generator, discriminator)

In [None]:
# Compile the model
GAN_model.compile(opt_obj, loss_obj)

In [None]:
hist = GAN_model.fit(data, epochs=20)

In [None]:
plt.suptitle('Loss')
plt.plot(hist.history['d_loss'], label='d_loss')
plt.plot(hist.history['g_loss'], label='g_loss')
plt.legend()
plt.show()

In [None]:
imgs = generator.predict(human_data[23]))
plt.imshow(img)