In [1]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

import numpy as np
import matplotlib.pyplot as plt

import os

In [2]:
from model import Discriminator_v1, Generator_v1
from datasets import download_celeb_data, chunk_datasets, re_chunk_datasets, DatManipulator
from objective import generator_loss_0, discrim_loss_0

In [171]:
# NOTE(review): `Timer` is not imported or defined in any cell visible here, and this cell's
# execution count (171) is out of sequence — confirm where `Timer` comes from before a fresh run.
uT = Timer()

In [3]:
# Ask TensorFlow which physical CPU and GPU devices it can see on this host.
cpus = tf.config.list_physical_devices("CPU")
gpus = tf.config.list_physical_devices("GPU")

In [4]:
# Training configuration.
# NOTE(review): this rebinds `gpus` to a single placeholder entry, so the per-replica batch
# below is computed for exactly one replica — confirm that this override is intended.
gpus = [None]
global_batch_size = 3000
# Floor division keeps the per-replica batch size an integer (true division gave 3000.0).
per_replica_batch = global_batch_size // len(gpus)
init_img_size = (32, 32)
targ_img_size = (256, 256)

In [5]:
# Distribution strategy, created with default arguments; it is handed to the data pipeline
# and its scope is used when the optimizers are created.
strategy_one = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [42]:
# download_celeb_data()
path = "./data_celeb"
# Create the data directory only when it is missing, so re-running this cell
# no longer fails with FileExistsError.
os.makedirs(path, exist_ok=True)
# chunk_datasets(path, num=8, del_left=True)
# re_chunk_datasets(path, num=5, del_left=True, suf_name='fold_chunk_')
data_object = DatManipulator(global_batch_size, path, strategy=strategy_one, target_image_size=targ_img_size,
                             chunk_first=False, non_stop=True)


In [35]:
# Build the generator and discriminator wrappers from the configured initial/target image sizes.
# NOTE(review): "deconv" presumably selects the generator's upsampling method — confirm in model.py.
G = Generator_v1(init_img_size, targ_img_size, "deconv", latent_space=256)
D = Discriminator_v1(init_img_size, targ_img_size)


# Initialise both base networks with explicit filter counts (plus dense units for the discriminator).
G.initialize_base(filters=(200, 200, 300, 400, 300))
D.initialize_base(filters=(200, 200, 300, 200, 100), units_dense=(130, ))

In [36]:
# Print the layer-by-layer summary of the generator's Keras model.
G.model.summary()

Model: "model_20"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_23 (InputLayer)       [(None, 256)]             0         
                                                                 
 map_z0 (Dense)              (None, 128)               32896     
                                                                 
 map_z1 (Dense)              (None, 128)               16512     
                                                                 
 map_z2 (Dense)              (None, 128)               16512     
                                                                 
 map_z3 (Dense)              (None, 200)               25800     
                                                                 
 batch_normalization_118 (Ba  (None, 200)              800       
 tchNormalization)                                               
                                                          

In [37]:
# Create one RMSprop optimizer per network (generator first, then discriminator),
# both with the same learning rate and both inside the distribution strategy's scope.
with strategy_one.scope():
    G_opt, D_opt = (tf.keras.optimizers.RMSprop(learning_rate=1e-4) for _ in range(2))

In [43]:
@tf.function
def train_step(global_train_step, true_img_k, D, G, z_k, clip_norm=3, log_dir="TrainLog", snap_short=None,
              extended_model=False):
    """Apply one discriminator update and one generator update, and log scalar summaries.

    Both losses and their gradients are computed before either optimizer is applied.
    The optimizers `D_opt` and `G_opt` are read from the enclosing notebook scope.

    Args:
        global_train_step: step value passed to every `tf.summary.scalar` call.
        true_img_k: batch of real images, forwarded to both loss functions as `train_sets`.
        D: discriminator wrapper; its Keras model is read from `D.model`.
        G: generator wrapper; its Keras model is read from `G.model`.
        z_k: latent sample, forwarded to both loss functions as `z`.
        clip_norm: if truthy, each gradient list is clipped to this global norm before
            it is applied.
        log_dir: currently unused.
        snap_short: image snapshots are not implemented; a truthy value raises.
        extended_model: not implemented; a truthy value raises.

    Returns:
        Tuple `(gen_L, disc_L, disc_g_loss)`: the generator loss, the discriminator loss,
        and the second value returned by `discrim_loss_0` (logged as "grad_D_penalty").

    Raises:
        NotImplementedError: if `snap_short` or `extended_model` is truthy.
    """
    # Both options are unfinished (the original had a bare `assert` and unreachable snapshot
    # code); fail fast with an explicit error before any gradient is computed or applied.
    if extended_model:
        raise NotImplementedError("extended_model is not implemented yet")
    if snap_short:
        raise NotImplementedError("image snapshots (snap_short) are not implemented yet")

    with tf.GradientTape() as D_tape:
        # NOTE(review): `grad_penalty` is not a parameter of this function and is not defined in
        # any visible cell; it must exist at notebook scope when the step is traced — confirm.
        disc_L, disc_g_loss = discrim_loss_0(z=z_k, G=G, D=D, train_sets=true_img_k, g_penalty=grad_penalty)
    disc_G = D_tape.gradient(disc_L, D.model.trainable_variables)

    with tf.GradientTape() as G_tape:
        gen_L = generator_loss_0(z=z_k, G=G, D=D, train_sets=true_img_k)
    gen_G = G_tape.gradient(gen_L, G.model.trainable_variables)

    if clip_norm:
        # clip_by_global_norm returns (clipped_list, global_norm); keep only the clipped
        # gradients so they can be zipped with the variables below.
        disc_G, _ = tf.clip_by_global_norm(disc_G, clip_norm, name='clip_norm_D')
        gen_G, _ = tf.clip_by_global_norm(gen_G, clip_norm, name='clip_norm_G')

    D_opt.apply_gradients(zip(disc_G, D.model.trainable_variables))
    G_opt.apply_gradients(zip(gen_G, G.model.trainable_variables))

    # Norms of the gradients that were actually applied (i.e. after clipping, when enabled).
    grad_G_norm = tf.linalg.global_norm(gen_G)
    grad_D_norm = tf.linalg.global_norm(disc_G)

    with tf.name_scope("MainLoss"):
        # Each tag logs its own network's loss (the two tensors were swapped before).
        tf.summary.scalar("G_Loss", tf.reduce_mean(gen_L), step=global_train_step)
        tf.summary.scalar("D_Loss", tf.reduce_mean(disc_L), step=global_train_step)
        tf.summary.scalar("G_grad_norm", tf.reduce_mean(grad_G_norm), step=global_train_step)
        tf.summary.scalar("D_grad_norm", tf.reduce_mean(grad_D_norm), step=global_train_step)

    with tf.name_scope("SupLoss"):
        tf.summary.scalar("grad_D_penalty", tf.reduce_mean(disc_g_loss), step=global_train_step)

    # The visible caller unpacks three values from this step; hand back the losses it prints
    # plus the penalty term.
    return gen_L, disc_L, disc_g_loss

In [45]:
def run(num_steps=1, latent_dim=256):
    """Run `num_steps` distributed training steps and print the mean losses of each one.

    Relies on the notebook-level `strategy_one`, `data_object`, `global_batch_size`, `D`, `G`
    and `train_step`. `train_step` is expected to return three values, of which the first
    two are the generator and discriminator losses.

    Args:
        num_steps: number of training steps to execute (default 1, as before).
        latent_dim: size of each latent vector; the default matches the `latent_space=256`
            used when `G` is constructed.
    """
    # Build the iterator once; calling iter() inside the loop creates a fresh iterator per step.
    batches = iter(data_object)
    # NOTE(review): assumes each replica receives global_batch_size / num_replicas real images,
    # so the latent batch is sized the same way — confirm against DatManipulator.
    replica_batch = global_batch_size // strategy_one.num_replicas_in_sync
    mean_op = tf.distribute.ReduceOp.MEAN

    for k_k in range(num_steps):
        # tf.random.normal requires an explicit shape.
        z_prior = tf.random.normal(shape=(replica_batch, latent_dim))
        # Pass the step as an int64 tensor: it is used as the summary step inside train_step.
        step = tf.constant(k_k, dtype=tf.int64)
        g_k_loss, d_k_loss, _ = strategy_one.run(train_step, args=(step, next(batches), D, G, z_prior))
        # Combine per-replica results before averaging; this also accepts plain tensors,
        # which is what a single-replica run returns.
        g_k_loss = strategy_one.reduce(mean_op, g_k_loss, axis=None)
        d_k_loss = strategy_one.reduce(mean_op, d_k_loss, axis=None)
        tf.print(k_k, 'generator_loss : ', tf.reduce_mean(g_k_loss), 'discriminator_loss : ', tf.reduce_mean(d_k_loss))
    

In [None]:
# Scratch check of the sampler: `shape` is a required argument of tf.random.normal,
# so draw two latent-sized (256) vectors instead of calling it with no arguments.
tf.random.normal(shape=(2, 256))

In [49]:
# Scratch check: global norm over a list of two random arrays of shapes (3, 10) and (3, 12).
tf.linalg.global_norm([np.random.rand(3, n_cols) for n_cols in (10, 12)])

<tf.Tensor: shape=(), dtype=float64, numpy=4.919951820182957>

In [100]:
# Scratch value (same number as global_batch_size) for the halving check in the next cell.
u = 3000

In [101]:
# Halve the integer `u`, rounding down (u - (u - u // 2) is exactly u // 2 for ints).
u //= 2

In [102]:
# Display the current value of `u`.
u

1500

In [103]:
# Scratch arithmetic; the recorded output for this cell is 2996.
428 * 7

2996