<a href="https://colab.research.google.com/github/takanto/CNN_GAN_512/blob/main/CNN_GAN_512_image_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **CNN GAN for 512 pixels**

This is a conventional convolutional (CNN) GAN that generates 512×512-pixel images. The model ships untrained, so it can learn to generate any kind of image depending on the training data you supply. You therefore need to train it yourself, first preparing roughly 50,000 images of size (512, 512). When training, do not forget to run on a TPU to speed up the process, and split the training across multiple sessions if necessary. (When splitting the training, it is recommended that you keep track of progress by adding the number of completed epochs to the names of the models you save.)

## **Code**

### **Preparation**

Run the code below before any operation.

・Import necessary libraries

In [None]:
import tensorflow as tf

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sys, os

from tensorflow.keras.layers import Input, Dense, Flatten, Conv2DTranspose, MaxPooling2D, Dropout, BatchNormalization, Reshape, LeakyReLU, Conv2D
from tensorflow.keras.applications.vgg16 import VGG16 as PretrainedModel, preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image

・ Setting up TPU environment

In [None]:
# Resolve the TPU cluster; tpu="" leaves the address to be auto-detected
# (assumed to be supplied by the Colab runtime — TODO confirm).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
# Connect to the cluster and initialise the TPU system before any model is built.
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# Print the logical TPU devices as a quick sanity check.
print("all devices:", tf.config.list_logical_devices("TPU"))

In [None]:
# Distribution strategy built on the TPU resolver; the models are later
# created and compiled inside `strategy.scope()`.
strategy = tf.distribute.TPUStrategy(resolver)

・ Load your training data. Please upload your training images; each image must be 512×512 pixels.

In [None]:
from PIL import Image
import os
import numpy as np

# Directory that holds the training images (each expected to be 512x512).
folder = 'name of your file containing training images'


def read(imname):
  """Load one image file and return it as an RGB uint8 array (H, W, 3).

  The file is opened in a context manager so its handle is released as soon
  as the pixels have been copied; with tens of thousands of images, leaving
  handles open can exhaust the process's file-descriptor limit.
  """
  with Image.open(imname) as im:
    return np.asarray(im.convert("RGB"))


# Sorted so the dataset order is reproducible across runs; anything that is
# not a regular file (e.g. a sub-directory) is skipped because it cannot be
# opened as an image.
image_paths = [os.path.join(folder, filename) for filename in sorted(os.listdir(folder))]
ims = [read(path) for path in image_paths if os.path.isfile(path)]
im_array = np.array(ims, dtype='uint8')

# Rescale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
# Cast to float32 first: dividing a uint8 array by a Python float would
# promote the whole dataset to float64 and double its memory footprint.
im_array = im_array.astype('float32') / 127.5 - 1.0

・　This function plots a 5×5 grid of 25 images produced by the generator and saves it to disk; it is called periodically during training.

In [None]:
def sample_images(epoch, model=None, output_dir="gan_images"):
  """Save a 5x5 grid of freshly generated images as ``<output_dir>/<epoch>.png``.

  Args:
    epoch: Integer used as the file name of the saved grid.
    model: Generator to sample from. Defaults to the global ``generator``
      so existing ``sample_images(epoch)`` calls keep working.
    output_dir: Directory the PNG is written to. It is created if missing.
      The default matches the ``gan_images`` folder created in the setup
      cell; the previous hard-coded ``cifar_gan_images`` folder was never
      created, so ``savefig`` failed on the first call.
  """
  if model is None:
    model = generator

  rows, cols = 5, 5
  noise = np.random.randn(rows*cols, latent_dim)
  imgs = model.predict(noise)

  # rescaling image from -1 to 1 to 0 to 1
  imgs = imgs * 0.5 + 0.5

  os.makedirs(output_dir, exist_ok=True)

  fig, axs = plt.subplots(rows, cols)
  # Each generated image is already (H, W, 3), so it can be shown directly
  # without a hard-coded reshape.
  for ax, img in zip(axs.flat, imgs):
    ax.imshow(img)
    ax.axis("off")
  fig.savefig(os.path.join(output_dir, "%d.png" % epoch))
  # Close this specific figure so repeated calls do not accumulate figures.
  plt.close(fig)

### **First time running**

Run the cells below the first time you train your model.

・ Define generator. It uses convolutional transpose layer for upsampling. 

In [None]:
# Size of the random noise vector fed to the generator. (The original line
# was misspelled `atent_dim`, which left `latent_dim` undefined.)
latent_dim = 100


def build_generator(latent_dim):
  """Build the generator: noise vector -> (512, 512, 3) image in [-1, 1].

  Args:
    latent_dim: Length of the input noise vector.

  Returns:
    An uncompiled Keras ``Model`` whose output shape is (None, 512, 512, 3).
  """
  i = Input(shape=(latent_dim,))

  # Start from an 8x8 feature map. With padding="same", each stride-4
  # transposed convolution multiplies the spatial size by 4, so the three
  # upsampling layers give 8 -> 32 -> 128 -> 512. (A 4x4 seed only reaches
  # 256x256, which does not match the 512x512 discriminator input.)
  x = Dense(256*8*8, activation = LeakyReLU(alpha=0.2)) (i)
  x = Reshape((8,8,256)) (x)

  for _ in range(3):
    x = Conv2DTranspose(128, (5,5), strides = 4, padding = "same", activation = LeakyReLU(alpha=0.2))(x)

  # Final stride-1 layer maps to 3 colour channels; tanh keeps values in [-1, 1].
  x = Conv2DTranspose(3, (3,3), padding = "same", activation = "tanh")(x)

  model = Model(i,x)
  return model

・　Check the output shape in the summary below. It should be (None, 512, 512, 3).

In [None]:
# Build a throwaway generator only to inspect its layers and output shape.
generator_check = build_generator(latent_dim)
generator_check.summary()

・ Define discriminator. It is a series of convolutional layer to discriminate generator images and training images.

In [None]:
def build_discriminator(image_size):
  """Build the discriminator: image -> probability-like score in (0, 1).

  Args:
    image_size: Shape of one input image, passed straight to ``Input``.

  Returns:
    An uncompiled Keras ``Model`` ending in a single sigmoid unit.
  """
  inputs = Input(shape=image_size)

  # Four stride-4 convolutions; only the number of filters differs per stage.
  features = inputs
  for n_filters in (64, 128, 128, 256):
    features = Conv2D(
        n_filters,
        (5,5),
        strides=4,
        padding = "same",
        activation=LeakyReLU(alpha=0.2),
    )(features)

  # Flatten, regularise with dropout, then score with one sigmoid unit.
  features = Flatten()(features)
  features = Dropout(0.4)(features)
  score = Dense(1, activation="sigmoid")(features)

  return Model(inputs, score)

・ Check the output shape in the summary below. It should be (None, 1).

In [None]:
# Build a throwaway discriminator only to inspect its layers and output shape.
discriminator_check = build_discriminator((512,512,3))
discriminator_check.summary()

・　Basic set up

In [None]:
# Images per training step.
batch_size = 128

# Label vectors used as training targets: ones for real images, zeros for
# generated ones. Both are sized to the batch.
ones, zeros = np.ones(batch_size), np.zeros(batch_size)

# Loss history, appended to once per training iteration.
d_losses, g_losses = [], []

# Create the output folder up front if it is missing.
image_dir = "gan_images"
if not os.path.exists(image_dir):
  os.makedirs(image_dir)

・　Below is the actual training process. Change the number of epochs and the sample period to suit your needs.

In [None]:
# Everything that creates variables (models, optimizers) is built inside the
# distribution strategy scope.
with strategy.scope():
  # The discriminator is compiled on its own first, while it is still trainable.
  discriminator = build_discriminator([512,512,3])
  discriminator.compile(
      loss = "binary_crossentropy",
      optimizer = Adam(learning_rate=0.0002, beta_1=0.5),
      metrics = ["accuracy"],
  )

  generator = build_generator(latent_dim)

  # Stacked model: noise -> generator -> discriminator.
  z = Input(shape=(latent_dim,))
  img = generator(z)

  # Freeze the discriminator before the stacked model is compiled, so that
  # training the stacked model only updates the generator.
  discriminator.trainable = False
  fake_pred = discriminator(img)

  combined_model = Model(inputs=z, outputs=fake_pred)
  combined_model.compile(
      loss = "binary_crossentropy",
      optimizer = Adam(learning_rate=0.0002, beta_1=0.5),
  )

# Number of training iterations (one batch each) and how often, in
# iterations, a grid of sample images is written.
epochs = 10000
sample_period = 100

for epoch in range(epochs):

  # ---- discriminator step ----

  # Random batch of real images (indices drawn with replacement).
  batch_idx = np.random.randint(0, im_array.shape[0], batch_size)
  real_batch = im_array[batch_idx]

  # Matching batch of generated images.
  d_noise = np.random.randn(batch_size, latent_dim)
  fake_batch = generator.predict(d_noise)

  # One update on real images (target 1) and one on fakes (target 0);
  # each call returns [loss, accuracy], which are averaged pairwise.
  real_stats = discriminator.train_on_batch(real_batch, ones)
  fake_stats = discriminator.train_on_batch(fake_batch, zeros)
  d_loss, d_acc = ((r + f) / 2 for r, f in zip(real_stats, fake_stats))

  # ---- generator step ----

  # The stacked model is trained towards the "real" label; with the
  # discriminator frozen inside it, only the generator is updated.
  g_noise = np.random.randn(batch_size, latent_dim)
  g_loss = combined_model.train_on_batch(g_noise, ones)

  d_losses.append(d_loss)
  g_losses.append(g_loss)

  # Progress report every 10 iterations.
  if epoch % 10 == 0:
    print(f"epoch: {epoch+1} / {epochs}, d_loss: {d_loss:.2f}, d_acc: {d_acc:.2f}, g_loss: {g_loss:.2f}")

  # Periodic sample grid.
  if epoch % sample_period == 0:
    sample_images(epoch)

・　After training, save the models and weights. Make sure to download them to your hard drive.

In [None]:
# Write both trained networks to .h5 files in the working directory.
generator.save("generator_model.h5")
discriminator.save("discriminator_model.h5")

In [None]:
def _as_object_array(weights):
  """Pack a list of differently shaped weight arrays into a 1-D object array.

  ``np.save`` converts its argument with ``np.asanyarray``; for a ragged list
  this raises ``ValueError`` on recent NumPy releases. Filling a preallocated
  object array element by element avoids both that error and NumPy's attempt
  to broadcast the arrays against each other.
  """
  packed = np.empty(len(weights), dtype=object)
  for k, w in enumerate(weights):
    packed[k] = w
  return packed


# Per-layer weight arrays of both networks, in the order Keras reports them.
generator_weights = generator.get_weights()
discriminator_weights = discriminator.get_weights()

# Output folder for the .npy files; created if it does not exist yet.
model_dir = "model_dir"
os.makedirs(model_dir, exist_ok=True)

# np.save appends the ".npy" extension to these names.
np.save(os.path.join(model_dir, 'gan_generator_weights'), _as_object_array(generator_weights))
np.save(os.path.join(model_dir, 'gan_discriminator_weights'), _as_object_array(discriminator_weights))

・ Run below when you want the plot of losses of discriminator and generator. 

In [None]:
# Plot the discriminator and generator losses recorded at each training iteration.
plt.plot(d_losses, label = "d_losses")
plt.plot(g_losses, label = "g_losses")
plt.legend()

### **Additional training of your model**

If you already have a trained model and weights and want to train the model further, run the code below.

・　Load your model and weights. (If you saved them under different names, please update the names here too.)

In [None]:
# Load inside the distribution strategy scope: these models are compiled and
# trained under `strategy.scope()` further down, and Keras requires a model's
# variables to be created under the same strategy it is compiled with.
with strategy.scope():
  generator_loaded = tf.keras.models.load_model("generator_model.h5")
  discriminator_loaded = tf.keras.models.load_model("discriminator_model.h5")

In [None]:
# Read back the weight arrays that were written with np.save.
# NOTE(review): allow_pickle=True unpickles the file contents, which can
# execute arbitrary code — only load .npy files you created yourself.
generator_weights = np.load("model_dir/gan_generator_weights.npy", allow_pickle=True)
discriminator_weights = np.load("model_dir/gan_discriminator_weights.npy", allow_pickle=True)

In [None]:
# Overwrite the loaded models' weights with the arrays read from model_dir.
generator_loaded.set_weights(generator_weights)
discriminator_loaded.set_weights(discriminator_weights)

・ Check your models

In [None]:
# generator check
noise = np.random.randn(1,latent_dim)
img = generator_loaded.predict(noise)
img = img * 0.5 + 0.5
plt.imshow(img.reshape(512,512,3))

In [None]:
# discriminator check
img = (img-0.5) / 0.5
discriminator_loaded.predict(img)

・ Training

In [None]:
# Images per training step.
batch_size = 128

# Label vectors used as training targets: ones for real images, zeros for
# generated ones. Both are sized to the batch.
ones, zeros = np.ones(batch_size), np.zeros(batch_size)

# Loss history, appended to once per training iteration.
d_losses, g_losses = [], []

# Create the output folder up front if it is missing.
image_dir = "gan_images"
if not os.path.exists(image_dir):
  os.makedirs(image_dir)

In [None]:
# Re-wire the loaded networks into a trainable GAN inside the strategy scope.
with strategy.scope():
  # The discriminator is (re)compiled on its own first, while still trainable.
  discriminator_loaded.compile(
      loss = "binary_crossentropy",
      optimizer = Adam(learning_rate=0.0002, beta_1=0.5),
      metrics = ["accuracy"],
  )

  # Stacked model: noise -> loaded generator -> loaded discriminator.
  z = Input(shape=(latent_dim,))
  img = generator_loaded(z)

  # Freeze the discriminator before the stacked model is compiled, so that
  # training the stacked model only updates the generator.
  discriminator_loaded.trainable = False
  fake_pred = discriminator_loaded(img)

  combined_model = Model(inputs=z, outputs=fake_pred)
  combined_model.compile(
      loss = "binary_crossentropy",
      optimizer = Adam(learning_rate=0.0002, beta_1=0.5),
  )

# NOTE: batch_size is intentionally not reassigned here. `ones` and `zeros`
# were sized from the value set in the setup cell, and resetting batch_size
# in this cell could silently desynchronise the label vectors from the
# image batches.

# Number of training iterations (one batch each) and how often, in
# iterations, a grid of sample images is written.
epochs = 10000
sample_period = 100

# sample_images() draws from the global name `generator`. Point it at the
# model being trained here; otherwise the sample grids would come from a
# stale generator (or fail with NameError in a fresh session).
generator = generator_loaded

for epoch in range(epochs):

  ## train discriminator ##

  # get random batches of real images
  idx = np.random.randint(0, im_array.shape[0], batch_size)
  real_imgs = im_array[idx]

  # generate fake images
  noise = np.random.randn(batch_size, latent_dim)
  fake_imgs = generator_loaded.predict(noise)

  # use train_on_batch to train discriminator: real images towards 1,
  # generated images towards 0, then average the two losses/accuracies
  d_loss_real, d_acc_real = discriminator_loaded.train_on_batch(real_imgs, ones)
  d_loss_fake, d_acc_fake = discriminator_loaded.train_on_batch(fake_imgs, zeros)
  d_loss = (d_loss_real + d_loss_fake) / 2
  d_acc = (d_acc_real + d_acc_fake) / 2

  ## train generator ##

  # The stacked model is trained towards the "real" label; with the
  # discriminator frozen inside it, only the generator is updated.
  noise = np.random.randn(batch_size, latent_dim)
  g_loss = combined_model.train_on_batch(noise, ones)

  d_losses.append(d_loss)
  g_losses.append(g_loss)

  # progress report every 10 iterations
  if epoch % 10 == 0:
    print(f"epoch: {epoch+1} / {epochs}, d_loss: {d_loss:.2f}, d_acc: {d_acc:.2f}, g_loss: {g_loss:.2f}")

  # periodic sample grid
  if epoch % sample_period == 0:
    sample_images(epoch)

・　After training, save the models and weights. Make sure to download them to your hard drive.

In [None]:
# Write both further-trained networks to .h5 files in the working directory
# (same file names as the first-time save).
generator_loaded.save("generator_model.h5")
discriminator_loaded.save("discriminator_model.h5")

In [None]:
def _as_object_array(weights):
  """Pack a list of differently shaped weight arrays into a 1-D object array.

  ``np.save`` converts its argument with ``np.asanyarray``; for a ragged list
  this raises ``ValueError`` on recent NumPy releases. Filling a preallocated
  object array element by element avoids both that error and NumPy's attempt
  to broadcast the arrays against each other.
  """
  packed = np.empty(len(weights), dtype=object)
  for k, w in enumerate(weights):
    packed[k] = w
  return packed


# Per-layer weight arrays of both networks, in the order Keras reports them.
generator_weights = generator_loaded.get_weights()
discriminator_weights = discriminator_loaded.get_weights()

# Output folder for the .npy files; created if it does not exist yet.
model_dir = "model_dir"
os.makedirs(model_dir, exist_ok=True)

# np.save appends the ".npy" extension to these names.
np.save(os.path.join(model_dir, 'gan_generator_weights'), _as_object_array(generator_weights))
np.save(os.path.join(model_dir, 'gan_discriminator_weights'), _as_object_array(discriminator_weights))

・ Run below when you want the plot of losses of discriminator and generator. 

In [None]:
# Plot the discriminator and generator losses recorded at each training iteration.
plt.plot(d_losses, label = "d_losses")
plt.plot(g_losses, label = "g_losses")
plt.legend()

## **TF Lite file for mobile app**

・ If you have saved the trained model and haven't done additional training, run below

In [None]:
# Reload the saved generator and restore its weights from the .npy file.
# NOTE(review): allow_pickle=True unpickles the file contents, which can
# execute arbitrary code — only load .npy files you created yourself.
generator_loaded = tf.keras.models.load_model("generator_model.h5")
generator_weights = np.load("model_dir/gan_generator_weights.npy", allow_pickle=True)

generator_loaded.set_weights(generator_weights)

・　The code below converts the model into a .tflite file, which can be used for mobile app development.

In [None]:
# Convert the Keras generator into a TensorFlow Lite model.
# Change `generator_loaded` to whatever your generator variable is currently called.
converter = tf.lite.TFLiteConverter.from_keras_model(generator_loaded)

tflite_model = converter.convert()

# Write the converted model to disk. (The original wrote the misspelled name
# `tfile_model`, which raised NameError after the conversion had finished.)
with open("generator_model.tflite", "wb") as f:
  f.write(tflite_model)