### Dependencies

In [1]:
!pip install matplotlib tensorflow tensorflow_addons tensorflow_datasets imageio





### Setup

In [1]:
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from tensorflow.keras import Model
import tensorflow as tf

from IPython import display

### Dataset

In [2]:
# Build a batched image dataset from the class sub-folders under the training
# directory; every image is resized to 32x32 and batches hold 16 images.
# Alternative source that was used previously:
#   "/home/tony/TO_BE_REMOVED/celeba_data/imgs/"
train_ds = tf.keras.utils.image_dataset_from_directory(
    directory="/home/tony/TO_BE_REMOVED/mnist_ds/mnist_jpg/training",
    batch_size=16,
    image_size=(32, 32),
    seed=123,
)

Found 60000 files belonging to 10 classes.


In [3]:
# for i_b, l_b in train_ds:
#     print(i_b.shape)
#     print(tf.image.rgb_to_grayscale(i_b).shape)

#### Note: 

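Images should be normalized to [-1, 1] so that they match the range of the generator's `tanh` output. ***(This is done inside `DCGAN.train`, not in the dataset cell above.)***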

### Generator network

#### Note:

Modify the network size for deployment

In [9]:
class Generator(Model):
    """Transposed-convolution generator that maps noise vectors to images.

    A bias-free dense layer projects each noise vector to an
    (H/8, W/8, 128) feature map. One stride-1 and three stride-2
    ``Conv2DTranspose`` layers then upsample it to (H, W, num_channel).
    The final layer uses ``tanh``, so outputs lie in [-1, 1].

    Args:
        noise_dim: Size of the last axis of the noise input.
        image_shape: ``(height, width)`` of the generated images; both must
            be multiples of 8 because the network upsamples by a factor of 8.
        num_channel: Number of channels in the generated images.
    """

    def __init__(self, noise_dim, image_shape, num_channel):
        super().__init__()

        assert len(image_shape) == 2
        assert image_shape[0] % 8 == 0
        assert image_shape[1] % 8 == 0

        self.noise_dim = noise_dim
        self.image_shape = image_shape
        self.num_channel = num_channel
        # Kernel size grows with the image size but is clamped to [3, 5].
        self.kernel_size = int(min(max(min(image_shape[0]/8.0-3.0, image_shape[1]/8.0-3.0), 3.0), 5.0))

        # Spatial size of the seed feature map, as exact integers so that the
        # Dense/Reshape sizes are never passed as floats.
        base_h = int(image_shape[0]) // 8
        base_w = int(image_shape[1]) // 8

        self.lr_d = layers.ReLU()
        self.lr_c1 = layers.ReLU()
        self.lr_c2 = layers.ReLU()
        self.lr_c3 = layers.ReLU()

        self.init_dense = layers.Dense(base_h * base_w * 128,
                                       use_bias=False, input_shape=(self.noise_dim,))

        self.init_reshape = layers.Reshape((base_h, base_w, 128))

        kernel = (self.kernel_size, self.kernel_size)
        self.conv2dT1 = layers.Conv2DTranspose(128, kernel,
                                               strides=(1, 1), padding='same')
        self.conv2dT2 = layers.Conv2DTranspose(64, kernel,
                                               strides=(2, 2), padding='same')
        self.conv2dT3 = layers.Conv2DTranspose(32, kernel,
                                               strides=(2, 2), padding='same')
        self.conv2dTactv = layers.Conv2DTranspose(self.num_channel, kernel,
                                                  strides=(2, 2), padding='same', activation='tanh')

    def call(self, noise_vec):
        """Generate images from noise.

        Args:
            noise_vec: Tensor whose last axis has size ``noise_dim``. All
                leading axes are flattened into a single batch axis, so both
                ``(batch, noise_dim)`` and ``(1, batch, noise_dim)`` inputs
                are accepted.

        Returns:
            Tensor of shape ``(batch, height, width, num_channel)`` with
            values in [-1, 1].
        """
        projected = self.lr_d(self.init_dense(noise_vec))

        # Flatten every leading axis into the batch axis explicitly. A bare
        # tf.squeeze would also drop the batch axis when the batch size is 1,
        # which makes the following Reshape fail.
        init_vec = tf.reshape(projected, (-1, self.init_dense.units))

        reshaped = self.init_reshape(init_vec)

        convt1 = self.lr_c1(self.conv2dT1(reshaped))
        convt2 = self.lr_c2(self.conv2dT2(convt1))
        convt3 = self.lr_c3(self.conv2dT3(convt2))

        out = self.conv2dTactv(convt3)

        return out

#### Testing

In [10]:
g1 = Generator(10, (32, 32), 3)

In [11]:
g1.kernel_size

3

In [12]:
noise_input = tf.random.normal((5, 10))
print(noise_input.shape)
# Pass the (batch, noise_dim) tensor directly; adding a leading axis with
# tf.expand_dims is unnecessary and only worked because the model removed it again.
pics1 = g1(noise_input)
print(pics1.shape)
# plt.imshow(pics1[-1, :, :, :], cmap='gray')

(5, 10)
(5, 32, 32, 3)


In [13]:
g1.summary()

Model: "generator_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 re_lu_4 (ReLU)              multiple                  0         
                                                                 
 re_lu_5 (ReLU)              multiple                  0         
                                                                 
 re_lu_6 (ReLU)              multiple                  0         
                                                                 
 re_lu_7 (ReLU)              multiple                  0         
                                                                 
 dense_1 (Dense)             multiple                  20480     
                                                                 
 reshape_1 (Reshape)         multiple                  0         
                                                                 
 conv2d_transpose_4 (Conv2DT  multiple                 

### Discriminator

In [14]:
class Discriminator(Model):
    """Convolutional critic that embeds images into a 256-dimensional vector.

    Three stride-2 ``Conv2D`` layers, each followed by ``LeakyReLU``,
    downsample the image by a factor of 8. The result is flattened and passed
    through a dense layer with no activation, so the output is an unbounded
    embedding rather than a probability.

    Args:
        image_shape: ``(height, width)`` of the input images; both must be
            multiples of 8.
        num_channel: Number of channels in the input images.
    """

    def __init__(self, image_shape, num_channel):
        super().__init__()

        assert len(image_shape) == 2
        assert image_shape[0] % 8 == 0
        assert image_shape[1] % 8 == 0

        self.image_shape = image_shape
        self.num_channel = num_channel
        # Kernel size grows with the image size but is clamped to [3, 5].
        self.kernel_size = int(min(max(min(image_shape[0]/8.0-3.0, image_shape[1]/8.0-3.0), 3.0), 5.0))

        self.lr_c1 = layers.LeakyReLU()
        self.lr_c2 = layers.LeakyReLU()
        self.lr_c3 = layers.LeakyReLU()
        self.flatten = layers.Flatten()

        kernel = (self.kernel_size, self.kernel_size)
        # No input_shape is given: layers of a subclassed model are built from
        # the first real input, and the previous value wrongly included a
        # batch axis.
        self.conv2d1 = layers.Conv2D(32, kernel, strides=(2, 2), padding='same')
        self.conv2d2 = layers.Conv2D(64, kernel, strides=(2, 2), padding='same')
        self.conv2d3 = layers.Conv2D(128, kernel, strides=(2, 2), padding='same')
        # Linear output (no sigmoid).
        self.dense_actv = layers.Dense(256)

    def call(self, img_input):
        """Embed a batch of images.

        Args:
            img_input: Tensor of shape ``(batch, height, width, num_channel)``.

        Returns:
            Tensor of shape ``(batch, 256)``. The batch axis is always kept,
            including for a batch of one, so callers can reduce over axis 1.
        """
        conv1 = self.lr_c1(self.conv2d1(img_input))
        conv2 = self.lr_c2(self.conv2d2(conv1))
        conv3 = self.lr_c3(self.conv2d3(conv2))

        flat = self.flatten(conv3)

        # Flatten + Dense already yields (batch, 256); squeezing here would
        # collapse a batch of one to shape (256,).
        out = self.dense_actv(flat)

        return out

#### Testing

In [15]:
d1 = Discriminator((32,32), 3)
g2 = Generator(10, (32,32), 3)
d1.kernel_size

3

In [16]:
noise_input = tf.random.normal((5, 10))
# Pass the (batch, noise_dim) tensor directly; the extra leading axis added by
# tf.expand_dims is not needed.
# NOTE(review): g2 is created in the previous cell but g1 is used here -
# confirm which generator was intended.
pics2 = g1(noise_input)
# plt.imshow(pics2[-1, :, :, 0], cmap='gray')
print(pics2.shape)

(5, 32, 32, 3)


In [17]:
deci = d1(pics2)
deci

<tf.Tensor: shape=(5, 256), dtype=float32, numpy=
array([[-1.0242338e-03,  7.4563571e-04,  4.4459841e-05, ...,
        -9.2679198e-05,  1.0957530e-03,  3.7814977e-04],
       [-9.4595744e-04,  5.6333444e-04, -1.8231114e-04, ...,
        -2.5268120e-04,  7.3711824e-04, -1.2429734e-04],
       [-1.2060391e-03, -5.1791896e-04, -1.0303807e-03, ...,
        -2.6774278e-04,  8.2628761e-04,  5.1804011e-05],
       [-9.8223228e-04,  4.9945875e-04, -2.1770452e-04, ...,
        -6.2967301e-05,  4.6027196e-04,  4.6945247e-04],
       [-1.0700801e-03,  1.0577107e-03,  7.7594472e-05, ...,
        -4.8216351e-04,  2.7823698e-04,  1.7998932e-04]], dtype=float32)>

In [18]:
d1.summary()

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 leaky_re_lu (LeakyReLU)     multiple                  0         
                                                                 
 leaky_re_lu_1 (LeakyReLU)   multiple                  0         
                                                                 
 leaky_re_lu_2 (LeakyReLU)   multiple                  0         
                                                                 
 flatten (Flatten)           multiple                  0         
                                                                 
 conv2d (Conv2D)             multiple                  896       
                                                                 
 conv2d_1 (Conv2D)           multiple                  18496     
                                                                 
 conv2d_2 (Conv2D)           multiple                

### CramerDCGAN

In [104]:
# Added inside every square root so the norm stays differentiable at zero.
EPSILON = 1e-16


class DCGAN:
    """Trainer that pairs a Generator and a Discriminator under a Cramer-style loss.

    The discriminator acts as a critic producing an embedding per image. The
    losses are built from L2 norms between embeddings of real images and of
    two independent batches of generated images, plus a gradient penalty on
    interpolated images.

    Args:
        dataset_path: Directory passed to
            ``tf.keras.utils.image_dataset_from_directory``.
        image_shape: ``(height, width)``; both must be multiples of 8.
        num_channel: Number of image channels the networks work with.
        noise_latent_dim: Size of the generator's noise input.
        disc_update_multi: The generator is updated on one of every
            ``disc_update_multi`` batches; the discriminator on every batch.
        batch_size: Batch size used when loading the dataset.
        lr: Learning rate for both Adam optimizers.
        gp_lam: Weight of the gradient-penalty term.
    """

    def __init__(self, dataset_path, image_shape, num_channel, noise_latent_dim, disc_update_multi=5,
                 batch_size=128, lr=3e-4, gp_lam=10.0):
        assert len(image_shape) == 2
        assert image_shape[0] % 8 == 0
        assert image_shape[1] % 8 == 0

        self.image_shape = image_shape
        self.num_channel = num_channel
        self.noise_latent_dim = noise_latent_dim
        self.batch_size, self.gp_lam = batch_size, gp_lam
        self.disc_update_multi = disc_update_multi
        # Number of fixed-seed images rendered by monitor_progress.
        self.num_img_prog_monit = 16

        self.dataset = tf.keras.utils.image_dataset_from_directory(
                              dataset_path,
                              seed=123,
                              image_size=self.image_shape,
                              batch_size=self.batch_size)
        # NOTE: Dataset must be processed differently for different source and applications

        self.g = Generator(self.noise_latent_dim, self.image_shape, self.num_channel)
        self.d = Discriminator(self.image_shape, self.num_channel)

        self.g_opt = tf.keras.optimizers.Adam(lr)
        self.d_opt = tf.keras.optimizers.Adam(lr)

        # Fixed noise so progress images are comparable across epochs.
        self.g_seed = tf.random.normal((self.num_img_prog_monit, self.noise_latent_dim))

    @staticmethod
    def _l2_norm(vec):
        """Row-wise L2 norm of a (batch, features) tensor, offset by EPSILON."""
        return tf.math.sqrt(tf.reduce_sum(vec**2, axis=1) + EPSILON)

    @classmethod
    def _critic(cls, emb, emb_ref):
        """Per-sample critic value ``||emb - emb_ref|| - ||emb||``."""
        return cls._l2_norm(emb - emb_ref) - cls._l2_norm(emb)

    def cramer_loss(self, d_x_data, d_g_z_1, d_g_z_2, x_it, update_gen=True):
        """Compute the generator and discriminator losses.

        Args:
            d_x_data: Discriminator embeddings of real images.
            d_g_z_1: Discriminator embeddings of the first generated batch.
            d_g_z_2: Discriminator embeddings of the second generated batch.
            x_it: Interpolated images on which the gradient penalty is taken.
            update_gen: When False the generator loss is skipped.

        Returns:
            ``(L_g, L_d)``; ``L_g`` is ``None`` when ``update_gen`` is False.
        """
        L_srg = self._critic(d_x_data, d_g_z_2) - self._critic(d_g_z_1, d_g_z_2)

        # Gradient of the critic with respect to the interpolated images.
        with tf.GradientTape() as t_gp:
            t_gp.watch(x_it)
            crit_it = self._critic(self.d(x_it), d_g_z_2)

        gp_grad = t_gp.gradient(crit_it, x_it)
        l2n_gp = tf.math.sqrt(tf.reduce_sum(gp_grad**2, axis=[1, 2, 3]) + EPSILON)

        # d_loss: negated surrogate plus penalty pulling the gradient norm to 1.
        L_d = tf.reduce_mean(-L_srg + (self.gp_lam * ((l2n_gp - 1.0)**2)))

        if update_gen:
            # g_loss
            L_g = tf.reduce_mean(self._l2_norm(d_x_data - d_g_z_1)
                                 + self._l2_norm(d_x_data - d_g_z_2)
                                 - self._l2_norm(d_g_z_1 - d_g_z_2))
        else:
            L_g = None

        return L_g, L_d

    @tf.function
    def update(self, imgs, update_gen=True):
        """Run one optimization step on a batch of normalized images.

        The discriminator is always updated; the generator only when
        ``update_gen`` is True.

        Args:
            imgs: Batch of images already scaled to [-1, 1].
            update_gen: Python bool selecting whether the generator is updated.

        Returns:
            ``(g_loss, d_loss)``; ``g_loss`` is ``None`` when ``update_gen``
            is False.
        """
        noise_input1 = tf.random.normal((imgs.shape[0], self.noise_latent_dim))
        noise_input2 = tf.random.normal((imgs.shape[0], self.noise_latent_dim))

        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            g_z_1 = self.g(noise_input1)
            g_z_2 = self.g(noise_input2)

            d_x_data = self.d(imgs)
            d_g_z_1 = self.d(g_z_1)
            d_g_z_2 = self.d(g_z_2)

            # Random per-sample mix of real and generated images.
            epsi = tf.random.uniform([imgs.shape[0], 1, 1, 1], 0.0, 1.0)
            x_it = tf.math.add(epsi * imgs, (1.0 - epsi) * g_z_1)
            g_loss, d_loss = self.cramer_loss(d_x_data, d_g_z_1, d_g_z_2, x_it, update_gen)

        # Compute every gradient before applying any update.
        if update_gen:
            grad_g = g_tape.gradient(g_loss, self.g.trainable_variables)
        grad_d = d_tape.gradient(d_loss, self.d.trainable_variables)

        if update_gen:
            self.g_opt.apply_gradients(zip(grad_g, self.g.trainable_variables))
        self.d_opt.apply_gradients(zip(grad_d, self.d.trainable_variables))

        return g_loss, d_loss

    def train(self, epochs=250, train_moni_path=None):
        """Train for ``epochs`` passes over the dataset.

        Args:
            epochs: Number of passes over ``self.dataset``.
            train_moni_path: Optional directory; when given, a grid of
                fixed-seed samples is saved there after every epoch.
        """
        num_training = 0
        for epo in range(epochs):
            g_losses = []
            d_losses = []
            for img_b, _ in self.dataset:
                if self.num_channel == 1 and img_b.shape[-1] == 3:
                    img_b = tf.image.rgb_to_grayscale(img_b)
                # Scale pixel values from [0, 255] to [-1, 1].
                norm_img_b = (img_b - 127.5) / 127.5

                update_gen = num_training % self.disc_update_multi == 0
                g_l, d_l = self.update(norm_img_b, update_gen)
                if update_gen:
                    g_losses.append(g_l.numpy())
                d_losses.append(d_l.numpy())

                num_training = (num_training + 1) % self.disc_update_multi

            # Avoid numpy's empty-mean warning when no batch was seen.
            g_avg = np.mean(g_losses) if g_losses else float('nan')
            d_avg = np.mean(d_losses) if d_losses else float('nan')
            print("Epoch {:04d}".format(epo), "Generator Avg. Loss: ", g_avg,
                  ", Discriminator Avg. Loss: ",  d_avg, flush=True)

            if train_moni_path is not None:
                self.monitor_progress(epo, train_moni_path)

    def _show_image(self, img):
        """Draw one generated image of shape (H, W, C) on the current axes."""
        if self.num_channel == 1:
            plt.imshow(img[:, :, 0], cmap='gray')
        else:
            # Map [-1, 1] back to integer pixel values in [0, 255].
            plt.imshow(tf.cast(tf.math.round(img * 127.5 + 127.5), tf.int32))
        plt.axis('off')

    def monitor_progress(self, epo, path):
        """Save a grid of images generated from the fixed seed.

        Args:
            epo: Epoch index, used in the file name ``image_XXXX.png``.
            path: Output directory; created if it does not exist.
        """
        if path:
            os.makedirs(path, exist_ok=True)

        pics = self.g(self.g_seed)
        num_pics = int(pics.shape[0])
        # Smallest square grid that fits every image (4x4 for 16 images).
        grid = max(1, int(np.ceil(np.sqrt(num_pics))))

        fig = plt.figure(figsize=(4, 4))
        for i in range(num_pics):
            plt.subplot(grid, grid, i + 1)
            self._show_image(pics[i])

        plt.savefig(os.path.join(path, 'image_{:04d}.png'.format(epo)))
        # Close only the figure created here.
        plt.close(fig)

    def save_weights(self, g_path, d_path):
        """Save generator and discriminator weights to the given paths."""
        self.g.save_weights(g_path)
        print("Saved generator weights", flush=True)
        self.d.save_weights(d_path)
        print("Saved discriminator weights", flush=True)

    def load_weights(self, g_path, d_path):
        """Load generator and discriminator weights from the given paths.

        A ``ValueError`` raised while loading is reported with a printed
        message instead of being propagated.
        """
        try:
            self.g.load_weights(g_path)
            print("Loaded generator weights", flush=True)
            self.d.load_weights(d_path)
            print("Loaded discriminator weights", flush=True)
        except ValueError:
            print("ERROR: Please make sure weights are saved as .ckpt", flush=True)

    def generate_samples(self, num_sam, path):
        """Generate ``num_sam`` images and save each as ``image_XXXX.png``.

        Args:
            num_sam: Number of images to generate.
            path: Output directory; created if it does not exist.
        """
        if path:
            os.makedirs(path, exist_ok=True)

        sam_seed = tf.random.normal((num_sam, self.noise_latent_dim))
        sam_pics = self.g(sam_seed)
        for i in range(sam_pics.shape[0]):
            fig = plt.figure()
            self._show_image(sam_pics[i])
            plt.savefig(os.path.join(path, 'image_{:04d}.png'.format(i)))
            plt.close(fig)
            

#### Testing

In [83]:
train_ds = tf.keras.utils.image_dataset_from_directory(
  "/home/tony/TO_BE_REMOVED/mnist_ds/mnist_jpg/training",
  seed=123,
  image_size=(32, 32),
  batch_size=256)

Found 60000 files belonging to 10 classes.


In [89]:
ds_path = "/home/tony/TO_BE_REMOVED/mnist_ds/mnist_jpg/training"

In [90]:
dcgan1 = DCGAN(ds_path, (32, 32), 1, 25, disc_update_multi=5)

Found 60000 files belonging to 10 classes.


In [91]:
dcgan1.train(25,'./imgs')
dcgan1.save_weights('./weights/g_test.ckpt', './weights/d_test.ckpt')

Epoch 0000 Generator Avg. Loss:  41.3724 , Discriminator Avg. Loss:  -29.27722
Epoch 0001 Generator Avg. Loss:  37.15732 , Discriminator Avg. Loss:  -25.44714
Epoch 0002 Generator Avg. Loss:  35.263073 , Discriminator Avg. Loss:  -23.928146
Epoch 0003 Generator Avg. Loss:  27.913044 , Discriminator Avg. Loss:  -17.721321
Epoch 0004 Generator Avg. Loss:  22.781689 , Discriminator Avg. Loss:  -12.914641
Epoch 0005 Generator Avg. Loss:  20.530334 , Discriminator Avg. Loss:  -10.875239
Epoch 0006 Generator Avg. Loss:  19.455324 , Discriminator Avg. Loss:  -9.466028
Epoch 0007 Generator Avg. Loss:  19.035072 , Discriminator Avg. Loss:  -8.864852
Epoch 0008 Generator Avg. Loss:  18.171598 , Discriminator Avg. Loss:  -7.996022
Epoch 0009 Generator Avg. Loss:  17.702473 , Discriminator Avg. Loss:  -7.164987
Epoch 0010 Generator Avg. Loss:  17.294514 , Discriminator Avg. Loss:  -6.5633073
Epoch 0011 Generator Avg. Loss:  16.996038 , Discriminator Avg. Loss:  -6.0955524
Epoch 0012 Generator Avg.

In [92]:
dcgan1.load_weights('./weights/g_test.ckpt', './weights/d_test.ckpt')

Loaded generator weights
Loaded discriminator weights


In [93]:
dcgan1.generate_samples(10, './samples/0')

In [94]:
dcgan2 = DCGAN(ds_path, (32, 32), 1, 25, disc_update_multi=5)

Found 60000 files belonging to 10 classes.


In [95]:
dcgan2.load_weights('./weights/g_test.ckpt', './weights/d_test.ckpt')

Loaded generator weights
Loaded discriminator weights


In [97]:
dcgan2.train(50, './imgs')

Epoch 0000 Generator Avg. Loss:  17.493753 , Discriminator Avg. Loss:  -4.0820417
Epoch 0001 Generator Avg. Loss:  17.52289 , Discriminator Avg. Loss:  -4.0324664
Epoch 0002 Generator Avg. Loss:  17.427225 , Discriminator Avg. Loss:  -3.9612927
Epoch 0003 Generator Avg. Loss:  17.620392 , Discriminator Avg. Loss:  -3.9564266
Epoch 0004 Generator Avg. Loss:  17.537416 , Discriminator Avg. Loss:  -3.9235697
Epoch 0005 Generator Avg. Loss:  17.557526 , Discriminator Avg. Loss:  -3.8974373
Epoch 0006 Generator Avg. Loss:  17.472876 , Discriminator Avg. Loss:  -3.8517308
Epoch 0007 Generator Avg. Loss:  17.587835 , Discriminator Avg. Loss:  -3.8309057
Epoch 0008 Generator Avg. Loss:  17.412464 , Discriminator Avg. Loss:  -3.783033
Epoch 0009 Generator Avg. Loss:  17.378683 , Discriminator Avg. Loss:  -3.7449489
Epoch 0010 Generator Avg. Loss:  17.583921 , Discriminator Avg. Loss:  -3.7130334
Epoch 0011 Generator Avg. Loss:  17.739964 , Discriminator Avg. Loss:  -3.6982608
Epoch 0012 Generat

In [98]:
dcgan2.save_weights('./weights/g_test.ckpt', './weights/d_test.ckpt')

Saved generator weights
Saved discriminator weights


In [99]:
dcgan1.load_weights('./weights/g_test.ckpt', './weights/d_test.ckpt')

Loaded generator weights
Loaded discriminator weights


In [100]:
dcgan1.generate_samples(10, './samples/1')

In [101]:
ds_celeba_path = "/home/tony/TO_BE_REMOVED/celeba_7z_alligned"

In [102]:
dcgan3 = DCGAN(ds_celeba_path, (64, 64), 3, 100, disc_update_multi=10)

Found 202599 files belonging to 1 classes.


In [None]:
dcgan3.train(1,'./imgs')
dcgan3.save_weights('./weights/g_celeb_test.ckpt', './weights/d_celeb_test.ckpt')