In [None]:

import copy
import functools
import traceback

import imlib as im
import numpy as np
import pylib as py
import scipy
import tensorflow as tf
import tflib as tl
import tfprob
import tqdm

import data
import module


# ==============================================================================
# =                                   param                                    =
# ==============================================================================

py.arg('--img_dir', default='./data/img_celeba/aligned/align_size(572,572)_move(0.250,0.000)_face_factor(0.450)_jpg/data')
py.arg('--load_size', type=int, default=256)
py.arg('--crop_size', type=int, default=256)
py.arg('--n_channels', type=int, choices=[1, 3], default=3)

py.arg('--n_epochs', type=int, default=160)
py.arg('--epoch_start_decay', type=int, default=160)
py.arg('--batch_size', type=int, default=64)
py.arg('--learning_rate', type=float, default=1e-4)
py.arg('--beta_1', type=float, default=0.5)
py.arg('--moving_average_decay', type=float, default=0.999)

py.arg('--n_d', type=int, default=1)  # number of D updates per G update
py.arg('--adversarial_loss_mode', choices=['gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'], default='hinge_v1')
py.arg('--gradient_penalty_mode', choices=['none', '1-gp', '0-gp', 'lp'], default='0-gp')
py.arg('--gradient_penalty_sample_mode', choices=['line', 'real', 'fake', 'real+fake', 'dragan', 'dragan_fake'], default='real')

py.arg('--d_loss_weight_x_gan', type=float, default=1)
py.arg('--d_loss_weight_x_gp', type=float, default=10)
py.arg('--d_lazy_reg_period', type=int, default=3)  # apply the gradient penalty every N D steps

py.arg('--g_loss_weight_x_gan', type=float, default=1)
py.arg('--g_loss_weight_orth_loss', type=float, default=1)  # if 0, use Gram–Schmidt orthogonalization (slower)

py.arg('--d_attribute_loss_weight', type=float, default=1.0)
py.arg('--g_attribute_loss_weight', type=float, default=10.0)
py.arg('--g_reconstruction_loss_weight', type=float, default=100.0)

py.arg('--weight_decay', type=float, default=0)

# one latent group per decoder stage; len(z_dims) must equal log2(crop_size / 4)
py.arg('--z_dims', type=int, nargs='+', default=[6] * 6)
py.arg('--eps_dim', type=int, default=512)

py.arg('--n_samples', type=int, default=100)
py.arg('--n_traversal', type=int, default=5)
py.arg('--n_left_axis_point', type=int, default=10)
# bound of the truncated normal used for sampling; fractional values are valid,
# so this must be parsed as a float (with type=int, "--truncation_threshold 1.5" was rejected)
py.arg('--truncation_threshold', type=float, default=1.5)

py.arg('--sample_period', type=int, default=1000)
py.arg('--traversal_period', type=int, default=2500)
py.arg('--checkpoint_save_period', type=int, default=10000)

py.arg('--experiment_name', default='default')
# `args` is parsed in the next cell from an explicit argument list.



In [None]:
# Parse an explicit argument list instead of the command line (notebook use).
# Other configurations used before, kept for reference:
#   --experiment_name Eigen128_0602_unet_recon100 --z_dims 7 7 7 7 7 --load_size 128 --crop_size 128
#   --experiment_name Eigen256_0524_unet_l --load_size 256 --crop_size 256 --batch_size 32
args = py.args([
    '--experiment_name', 'Eigen128_0526_unet_recon100',
    '--z_dims', '7', '7', '7', '7', '7',
    '--load_size', '128',
    '--crop_size', '128',
])


In [None]:



import functools

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tflib as tl


from tqdm.auto import tqdm, trange
from pdb import set_trace
    


class DD(tl.Module):
    """Discriminator with two heads on a shared convolutional trunk.

    ``call`` returns ``(logit_gan, logit_att)``: one adversarial logit per
    example and ``n_atts`` regression outputs.
    """

    def call(self,
             x,
             n_atts,
             dim_10=4,
             fc_dim=1024,
             n_downsamplings=6,
             weight_norm='none',
             feature_norm='none',
             act=tf.nn.leaky_relu,
             training=True):
        max_dim = 512

        def nd(size):
            # channel width for a map of spatial size `size`, capped at max_dim
            return min(int(2**(10 - np.log2(size)) * dim_10), max_dim)

        def width_of(t, divisor=1):
            # channel width derived from the static height of `t` (optionally divided)
            return nd(t.shape[1].value // divisor)

        w_norm = tl.get_weight_norm(weight_norm, training)

        def with_defaults(layer_fn):
            # same initializer / weight normalizer / L2 regularizer for every layer type
            return functools.partial(layer_fn,
                                     weights_initializer=tl.get_initializer(act),
                                     weights_normalizer_fn=w_norm,
                                     weights_regularizer=slim.l2_regularizer(1.0))

        conv = with_defaults(tl.conv2d)
        fc = with_defaults(tl.fc)

        f_norm = tl.get_feature_norm(feature_norm, training, updates_collections=None)
        conv_norm_act = functools.partial(conv, normalizer_fn=f_norm, activation_fn=act)

        # stem
        h = act(conv(x, width_of(x), 7, 1))
        # each stage: 3x3 conv at the current size, then a stride-2 3x3 conv
        for _ in range(n_downsamplings):
            h = conv_norm_act(h, width_of(h), 3, 1)
            h = conv_norm_act(h, width_of(h, 2), 3, 2)
        h = conv_norm_act(h, width_of(h), 3, 1)

        # shared fully connected layer, then the heads (adversarial first, attributes second)
        h = act(fc(slim.flatten(h), min(fc_dim, max_dim)))
        logit_gan = fc(h, 1)
        logit_att = fc(h, n_atts)
        return logit_gan, logit_att
    
    
    
class UNetGenc(tl.Module):
    """U-Net encoder that returns every intermediate feature map.

    ``call`` returns a list of ``2 * n_downsamplings`` tensors ordered from the
    highest to the lowest resolution; each stage contributes its stride-1 output
    followed by its stride-2 output.
    """

    def call(self,
             x,
             dim_10=4,
             n_channels=3,
             n_downsamplings=6,
             weight_norm='none',
             feature_norm='none',
             act=tf.nn.leaky_relu,
             training=True):
        # `n_channels` is accepted but not used by the encoder.
        max_dim = 512

        def nd(size):
            # channel width for a map of spatial size `size`, capped at max_dim
            return min(int(2**(10 - np.log2(size)) * dim_10), max_dim)

        def width_of(t, divisor=1):
            # channel width derived from the static height of `t` (optionally divided)
            return nd(t.shape[1].value // divisor)

        w_norm = tl.get_weight_norm(weight_norm, training)
        conv = functools.partial(tl.conv2d,
                                 weights_initializer=tl.get_initializer(act),
                                 weights_normalizer_fn=w_norm,
                                 weights_regularizer=slim.l2_regularizer(1.0))

        f_norm = tl.get_feature_norm(feature_norm, training, updates_collections=None)
        conv_norm_act = functools.partial(conv, normalizer_fn=f_norm, activation_fn=act)

        features = []
        h = act(conv(x, width_of(x), 7, 1))
        for _ in range(n_downsamplings):
            h = conv_norm_act(h, width_of(h), 3, 1)
            features.append(h)
            h = conv_norm_act(h, width_of(h, 2), 3, 2)
            features.append(h)
        return features
    
    


class UNetGdec(tl.Module):
    # U-Net style EigenGAN decoder: starts from the deepest encoder feature map and,
    # for each latent group in `zs`, adds a learned linear-subspace embedding of that
    # group before upsampling; optionally concatenates encoder skip features.

    def call(self,
             zs,
             eps,
             dim_10=4,
             n_channels=3,
             weight_norm='none',
             feature_norm='none',
             act=tf.nn.leaky_relu,
             use_gram_schmidt=True,
             training=True,
            shortcut_layers=1):
        """Decode per-stage latent codes plus encoder features into an image.

        Args:
            zs: list of latent-code tensors, one per decoder stage; stage i uses
                a basis of spatial size 4 * 2**i. Presumably each z is
                (batch, z_dim) — TODO confirm against callers.
            eps: list of encoder feature maps (in this notebook, the output of
                UNetGenc). eps[-1] is the starting map — assumed to be 4x4 with
                nd(4) channels so it can be added to the stage-0 embedding.
                eps[-2 - 2*i] is concatenated at stage i when shortcut_layers > i.
            dim_10: channel-width multiplier used by `nd`.
            n_channels: number of output image channels.
            weight_norm, feature_norm: names passed to tl.get_weight_norm /
                tl.get_feature_norm.
            act: activation function.
            use_gram_schmidt: if True, additionally orthonormalize each basis U
                with tl.gram_schmidt (the orthogonal regularizer is attached either way).
            training: passed to the normalization helpers.
            shortcut_layers: number of leading stages (lowest resolution first)
                that receive an encoder skip connection.

        Returns:
            tanh-activated image tensor with `n_channels` channels.
        """
        MAX_DIM = 512
        # channel width for a map of spatial size `size`, capped at MAX_DIM
        nd = lambda size: min(int(2**(10 - np.log2(size)) * dim_10), MAX_DIM)

        w_norm = tl.get_weight_norm(weight_norm, training)
        transposed_w_norm = tl.get_weight_norm(weight_norm, training, transposed=True)
        # NOTE(review): `fc` is never used in this method.
        fc = functools.partial(tl.fc, weights_initializer=tl.get_initializer(act), weights_normalizer_fn=w_norm, weights_regularizer=slim.l2_regularizer(1.0))
        conv = functools.partial(tl.conv2d, weights_initializer=tl.get_initializer(act), weights_normalizer_fn=w_norm, weights_regularizer=slim.l2_regularizer(1.0))
        dconv = functools.partial(tl.dconv2d, weights_initializer=tl.get_initializer(act), weights_normalizer_fn=transposed_w_norm, weights_regularizer=slim.l2_regularizer(1.0))
        f_norm = tl.get_feature_norm(feature_norm, training, updates_collections=None)
        # identity when no feature normalization is configured
        f_norm = (lambda x: x) if f_norm is None else f_norm

        def orthogonal_regularizer(U):
            # 0.5 * ||U^T U - I||_F^2 with U flattened to (-1, z_dim). The Gram matrix is
            # also stored in the 'orth' collection (G_train_graph turns it into an image summary).
            with tf.name_scope('orthogonal_regularizer'):
                U = tf.reshape(U, [-1, U.shape[-1]])
                orth = tf.matmul(tf.transpose(U), U)
                tf.add_to_collections(['orth'], orth)
                return 0.5 * tf.reduce_sum((orth - tf.eye(U.shape[-1].value)) ** 2)

            
        # start from the deepest encoder feature map
        h=eps[-1]
        
        
        for i, z in enumerate(zs):
            height = width = 4 * 2 ** i

            # subspace basis: one direction of shape (height, width, nd(height)) per latent dim
            U = tf.get_variable('U_%d' % i,
                                shape=[height, width, nd(height), z.shape[-1]],
                                initializer=tf.initializers.orthogonal(),
                                regularizer=orthogonal_regularizer,
                                trainable=True)
            if use_gram_schmidt:
                # orthonormalize the flattened basis vectors, then restore the shape
                U = tf.transpose(tf.reshape(U, [-1, U.shape[-1]]))
                U = tl.gram_schmidt(U)
                U = tf.reshape(tf.transpose(U), [height, width, nd(height), z.shape[-1]])

            # per-dimension scale, initialized to 3*d, 3*(d-1), ..., 3 for d = z_dim
            # (the comprehension's `i` is local to it and does not clobber the loop `i`)
            L = tf.get_variable('L_%d' % i,
                                shape=[z.shape[-1]],
                                initializer=tf.initializers.constant([3 * i for i in range(z.shape[-1], 0, -1)]),
                                trainable=True)

            # subspace origin
            mu = tf.get_variable('mu_%d' % i,
                                 shape=[height, width, nd(height)],
                                 initializer=tf.initializers.zeros(),
                                 trainable=True)

            # h_[b] = mu + sum_k L[k] * z[b, k] * U[..., k]
            h_ = tf.reduce_sum(U[None, ...] * (L[None, :] * z)[:, None, None, None, :], axis=-1) + mu[None, ...]

            h_1 = dconv(h_, nd(height), 1, 1)
            
            # With a skip connection, h_2 is added after the concatenation below, so it
            # needs twice the channels (assumes eps[-2 - 2*i] has nd(height * 2) channels).
            if shortcut_layers > i:
                h_2 = dconv(h_, nd(height * 2)*2, 3, 2)
            else:
                h_2 = dconv(h_, nd(height * 2), 3, 2)
            
            
            #deconv1
            h=act(f_norm(h + h_1))
            #if shortcut_layers > i:
            #    h = tl.tile_concat([h, eps[-1 - 2*i]])
            h = dconv(h, nd(height * 2), 3, 2)

            
            
            if shortcut_layers > i:
                # concatenate the encoder feature map matching the stride-2 dconv output
                h = tl.tile_concat([h, eps[-2 - 2*i]])
            #deconv2
            h=act(f_norm(h + h_2))
            h = dconv(h, nd(height * 2), 3, 1)
            
        # project to image channels
        x = tf.tanh(conv(act(h), n_channels, 7, 1))

        return x    

    
    
    

In [None]:
import numpy as np
import pylib as py
import tensorflow as tf
import tflib as tl


def make_dataset(img_paths,
                 batch_size,
                 load_size=286,
                 crop_size=256,
                 n_channels=3,
                 training=True,
                 drop_remainder=True,
                 shuffle=True,
                 repeat=1):
    """Build a batched image dataset from a list of file paths.

    Each image is optionally converted to grayscale, resized to
    ``load_size`` x ``load_size``, randomly flipped left/right (training only),
    center-cropped to ``crop_size`` and mapped from [0, 255] to [-1, 1].

    Returns:
        ``(dataset, len_dataset)`` where ``len_dataset`` is the number of
        batches in one pass over ``img_paths``.
    """
    if shuffle:
        img_paths = np.random.permutation(img_paths)

    def _map_fn(img):
        if n_channels == 1:
            img = tf.image.rgb_to_grayscale(img)
        img = tf.image.resize(img, [load_size, load_size])
        if training:
            img = tf.image.random_flip_left_right(img)
        img = tl.center_crop(img, size=crop_size)
        return tf.clip_by_value(img, 0, 255) / 127.5 - 1

    dataset = tl.disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          drop_remainder=drop_remainder,
                                          map_fn=_map_fn,
                                          shuffle=shuffle,
                                          repeat=repeat)

    n_imgs = len(img_paths)
    len_dataset = n_imgs // batch_size if drop_remainder else int(np.ceil(n_imgs / batch_size))
    return dataset, len_dataset



In [None]:

    
# check: the decoder has one stage per latent group, starting at 4x4 and doubling
# each stage, so crop_size must equal 4 * 2 ** len(z_dims)
assert np.log2(args.crop_size / 4) == len(args.z_dims), \
    'crop_size (%d) must equal 4 * 2 ** len(z_dims) = %d' % (args.crop_size, 4 * 2 ** len(args.z_dims))

# output_dir
output_dir = py.join('output', args.experiment_name)
py.mkdir(output_dir)

# save settings
py.args_to_yaml(py.join(output_dir, 'settings.yml'), args)

sess = tl.session()


# ==============================================================================
# =                                    data                                    =
# ==============================================================================


img_paths = sorted(py.glob(args.img_dir, '*'))
assert len(img_paths) > 0, 'no images found in %s' % args.img_dir
# 95% / 5% train / validation split of the sorted file list
n_img_train = int(len(img_paths) * 0.95)
img_paths_train = img_paths[:n_img_train]
img_paths_test = img_paths[n_img_train:]

# one validation batch feeds both the sample grid and the traversal rows
val_batch_size = max(args.n_traversal, args.n_samples)
# make_dataset drops incomplete batches by default, so a validation split smaller
# than one batch would yield no data and the first sample/traversal would fail mid-training
assert len(img_paths_test) >= val_batch_size, \
    'validation split has %d images, need at least %d' % (len(img_paths_test), val_batch_size)

train_dataset, len_train_dataset = make_dataset(img_paths_train, args.batch_size, load_size=args.load_size, crop_size=args.crop_size, n_channels=args.n_channels, repeat=None)
train_iter = train_dataset.make_one_shot_iterator()

val_dataset, len_val_dataset = make_dataset(img_paths_test, val_batch_size, load_size=args.load_size, crop_size=args.crop_size, n_channels=args.n_channels, shuffle=False, repeat=None, training=False)
val_iter = val_dataset.make_one_shot_iterator()


# ==============================================================================
# =                                   model                                    =
# ==============================================================================

# D's attribute head regresses all latent codes at once, hence n_atts = sum(z_dims)
D = functools.partial(DD(scope='D'), n_atts=sum(args.z_dims), n_downsamplings=len(args.z_dims))
# NOTE(review): the scope names are swapped (encoder under 'Gdec', decoder under
# 'Genc'). Left as-is because renaming changes checkpoint variable names and would
# break restoring existing runs; the notebook refers to the modules through `.func`
# (including `Gdec.func.scope`), so this appears to affect naming only.
Genc = functools.partial(UNetGenc(scope='Gdec'), n_channels=args.n_channels, n_downsamplings=len(args.z_dims))
Gdec = functools.partial(UNetGdec(scope='Genc'), n_channels=args.n_channels, use_gram_schmidt=args.g_loss_weight_orth_loss == 0)
# evaluation decoder; its weights are copied from the EMA of Gdec by clone_graph
G_test = functools.partial(UNetGdec(scope='G_test'), n_channels=args.n_channels, use_gram_schmidt=args.g_loss_weight_orth_loss == 0, training=False)

# exponential moving average
G_ema = tf.train.ExponentialMovingAverage(decay=args.moving_average_decay, name='G_ema')

# loss function
d_loss_fn, g_loss_fn = tfprob.get_adversarial_losses_fn(args.adversarial_loss_mode)


# ==============================================================================
# =                                   graph                                    =
# =============================================================================



def D_train_graph():
    """Build the discriminator update and return a ``run(lr=...)`` callable.

    Loss = adversarial loss on (x_r, x_f)
         + lazily applied gradient penalty on the adversarial logit
         + weight decay
         + attribute regression: D's second head must recover the codes ``zs``
           that generated ``x_f``.
    ``run`` performs ``args.n_d`` D updates.
    """
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    x_r = train_iter.get_next()
    zs = [tf.random.normal([args.batch_size, z_dim]) for z_dim in args.z_dims]
    eps = tf.random.normal([args.batch_size, args.eps_dim])

    # counter
    step_cnt, _ = tl.counter()

    # optimizer
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)

    def D_adv(x):
        # Adversarial logit only. D returns (logit_gan, logit_att); passing D itself
        # to the gradient penalty would also differentiate the attribute head
        # (assuming tfprob.gradient_penalty uses tf.gradients on f(x), which sums
        # the gradients of every output in a list/tuple).
        return D(x)[0]

    def graph_per_gpu(x_r, zs, eps):

        # generate (`eps` noise is replaced by the encoder features of x_r)
        eps = Genc(x_r)
        x_f = Gdec(zs, eps)

        # discriminate
        x_r_logit, _ = D(x_r)
        x_f_logit, x_f_logit_att = D(x_f)

        # loss
        x_r_loss, x_f_loss = d_loss_fn(x_r_logit, x_f_logit)
        # lazy regularization: penalty only every d_lazy_reg_period steps, scaled up to compensate
        x_gp = tf.cond(tf.equal(step_cnt % args.d_lazy_reg_period, 0),
                       lambda: tfprob.gradient_penalty(D_adv, x_r, x_f, args.gradient_penalty_mode, args.gradient_penalty_sample_mode) * args.d_lazy_reg_period,
                       lambda: tf.constant(0.0))
        if args.d_loss_weight_x_gp == 0:
            x_gp = tf.constant(0.0)

        reg_loss = tf.reduce_sum(D.func.reg_losses)

        # attribute head should predict the codes that produced x_f
        att_loss = tf.losses.mean_squared_error(tf.concat(zs, axis=1), x_f_logit_att)

        loss = (
            (x_r_loss + x_f_loss) * args.d_loss_weight_x_gan +
            x_gp * args.d_loss_weight_x_gp +
            reg_loss * args.weight_decay +
            att_loss * args.d_attribute_loss_weight
        )

        # optim
        grads = optimizer.compute_gradients(loss, var_list=D.func.trainable_variables)

        return grads, x_r_loss, x_f_loss, x_gp, reg_loss, att_loss

    (split_grads, split_x_r_loss, split_x_f_loss, split_x_gp,
     split_reg_loss, split_att_loss) = zip(*tl.parellel_run(tl.gpus(), graph_per_gpu, tl.split_nest((x_r, zs, eps), len(tl.gpus()))))
    grads = tl.average_gradients(split_grads)
    x_r_loss, x_f_loss, x_gp, reg_loss, att_loss = [tf.reduce_mean(t) for t in [split_x_r_loss, split_x_f_loss, split_x_gp, split_reg_loss, split_att_loss]]

    step = optimizer.apply_gradients(grads, global_step=step_cnt)

    # summary
    summary = tl.create_summary_statistic_v2(
        {'x_gan_loss': x_r_loss + x_f_loss,
         'x_gp': x_gp,
         'reg_loss': reg_loss,
         'att_loss': att_loss,
         'lr': lr},
        './output/%s/summaries/D' % args.experiment_name,
        step=step_cnt,
        n_steps_per_record=10,
        name='D'
    )

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        for _ in range(args.n_d):
            sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run


def G_train_graph():
    """Build the encoder/decoder update and return a ``run(lr=...)`` callable.

    Losses: adversarial loss on x_f = Gdec(zs, Genc(x_r)); orthogonality of the
    subspace bases; weight decay; attribute regression (D's attribute head should
    recover zs from x_f); and an L1 reconstruction of x_r decoded from the codes
    D predicts for it. Only Genc and Gdec variables are updated; afterwards the
    Gdec weights are folded into the G_ema moving average.
    """
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    zs = [tf.random.normal([args.batch_size, z_dim]) for z_dim in args.z_dims]
    eps = tf.random.normal([args.batch_size, args.eps_dim])
    x_r = train_iter.get_next()

    # counter
    step_cnt, _ = tl.counter()

    # optimizer
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)

    def graph_per_gpu(x_r, zs, eps):
        # `x_r` is this device's slice of the real batch, split together with zs/eps
        # so batch sizes agree on every GPU; `eps` noise is replaced by encoder features.

        # latent codes D predicts for the real images, split back into per-layer
        # groups (size_splits, so unequal z_dims are handled correctly)
        _, zs_a = D(x_r)
        zs_a = tf.split(zs_a, args.z_dims, axis=1)

        # generate
        eps = Genc(x_r)
        x_f = Gdec(zs, eps)    # random codes + real-image features
        x_a = Gdec(zs_a, eps)  # predicted codes, should reconstruct x_r

        # discriminate
        x_f_logit, x_f_logit_att = D(x_f)

        # loss
        x_f_loss = g_loss_fn(x_f_logit)
        orth_loss = tf.reduce_sum(tl.tensors_filter(Gdec.func.reg_losses, 'orthogonal_regularizer'))
        reg_loss_Gdec = tf.reduce_sum(tl.tensors_filter(Gdec.func.reg_losses, 'l2_regularizer'))
        reg_loss_Genc = tf.reduce_sum(tl.tensors_filter(Genc.func.reg_losses, 'l2_regularizer'))
        reg_loss = reg_loss_Gdec + reg_loss_Genc

        att_loss = tf.losses.mean_squared_error(tf.concat(zs, axis=1), x_f_logit_att)
        rec_loss = tf.losses.absolute_difference(x_r, x_a)

        loss = (
            x_f_loss * args.g_loss_weight_x_gan +
            orth_loss * args.g_loss_weight_orth_loss +
            reg_loss * args.weight_decay +
            att_loss * args.g_attribute_loss_weight +
            rec_loss * args.g_reconstruction_loss_weight
        )

        # optim (D is excluded from var_list, so it is not updated here)
        grads = optimizer.compute_gradients(loss, var_list=Genc.func.trainable_variables + Gdec.func.trainable_variables)

        return grads, x_f_loss, orth_loss, reg_loss, att_loss, rec_loss

    (split_grads, split_x_f_loss, split_orth_loss, split_reg_loss,
     split_att_loss, split_rec_loss) = zip(*tl.parellel_run(tl.gpus(), graph_per_gpu, tl.split_nest((x_r, zs, eps), len(tl.gpus()))))
    grads = tl.average_gradients(split_grads)
    x_f_loss, orth_loss, reg_loss, att_loss, rec_loss = [tf.reduce_mean(t) for t in [split_x_f_loss, split_orth_loss, split_reg_loss, split_att_loss, split_rec_loss]]

    step = optimizer.apply_gradients(grads, global_step=step_cnt)

    # moving average of the decoder only (G_test is cloned from it)
    with tf.control_dependencies([step]):
        step = G_ema.apply(Gdec.func.trainable_variables)

    # summary
    summary_dict = {'x_f_loss': x_f_loss,
                    'orth_loss': orth_loss,
                    'reg_loss': reg_loss,
                    'att_loss': att_loss,
                    'rec_loss': rec_loss}
    summary_dict.update({'L_%d' % i: t for i, t in enumerate(tl.tensors_filter(Genc.func.trainable_variables + Gdec.func.trainable_variables, 'L'))})
    summary_loss = tl.create_summary_statistic_v2(
        summary_dict,
        './output/%s/summaries/G' % args.experiment_name,
        step=step_cnt,
        n_steps_per_record=10,
        name='G_loss'
    )

    summary_image = tl.create_summary_image_v2(
        {'orth_U_%d' % i: t[None, :, :, None] for i, t in enumerate(tf.get_collection('orth', Gdec.func.scope + '/'))},
        './output/%s/summaries/G' % args.experiment_name,
        step=step_cnt,
        n_steps_per_record=10,
        name='G_image'
    )

    # ======================================
    # =             model size             =
    # ======================================

    n_params, n_bytes = tl.count_parameters(Genc.func.trainable_variables + Gdec.func.trainable_variables)
    print('Model Size: n_parameters = %d = %.2fMB' % (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary_loss, summary_image], feed_dict={lr: pl_ipts['lr']})

    return run


def sample_graph():
    """Build the periodic sampling op and return ``run(epoch, iter)``.

    Random truncated-normal latent codes are decoded by G_test together with the
    encoder features of a validation batch; the grid is written to
    ``./output/<experiment_name>/samples_training/sample``.
    """
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    zs = [tl.truncated_normal([args.n_samples, z_dim], minval=-args.truncation_threshold, maxval=args.truncation_threshold) for z_dim in args.z_dims]
    # channel count follows --n_channels (a hard-coded 3 rejected grayscale batches)
    xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, args.n_channels])

    # generate
    x_r = val_iter.get_next()
    x_f = G_test(zs, Genc(xa, training=False), training=False)

    # ======================================
    # =            run function            =
    # ======================================

    save_dir = './output/%s/samples_training/sample' % (args.experiment_name)
    py.mkdir(save_dir)

    def run(epoch, iter):
        # the random codes have a fixed batch of n_samples, so feed exactly that many images
        xa_ipt = sess.run(x_r)[:args.n_samples]
        x_f_opt = sess.run(x_f, feed_dict={xa: xa_ipt})
        sample = im.immerge(x_f_opt, n_rows=int(args.n_samples ** 0.5))
        im.imwrite(sample, '%s/Epoch-%d_Iter-%d.jpg' % (save_dir, epoch, iter))

    return run




def traversal_graph():
    """Build latent-traversal ops and return ``run(epoch, iter)``.

    For a validation batch, D's attribute head predicts per-layer latent codes.
    For every layer ``l`` and every dimension of that layer (ordered by
    decreasing ``|L|`` of G_test), that dimension is swept over
    ``np.linspace(-4.5, 4.5, 2 * n_left_axis_point + 1)`` while all other codes
    keep their predicted values. Each saved grid shows, left to right: the real
    image, its reconstruction from the predicted codes, then the sweep.
    """
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    zs = [tf.placeholder(dtype=tf.float32, shape=[args.n_traversal, z_dim]) for z_dim in args.z_dims]
    x = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, args.n_channels])

    # generate
    x_r = val_iter.get_next()
    _, x_r_zs = D(x, training=False)
    # size_splits (not a count), so unequal z_dims are split correctly
    x_r_zs = tf.split(x_r_zs, args.z_dims, axis=1)

    x_f = G_test(zs, Genc(x, training=False), training=False)

    # ======================================
    # =            run function            =
    # ======================================

    save_dir = './output/%s/samples_training/traversal' % (args.experiment_name)
    py.mkdir(save_dir)

    vals = np.linspace(-4.5, 4.5, args.n_left_axis_point * 2 + 1)

    def _feed(zs_ipt, x_ipt):
        # feed dict for x_f: one array per latent placeholder plus the image batch
        feed_dict = {z: z_ipt for z, z_ipt in zip(zs, zs_ipt)}
        feed_dict[x] = x_ipt
        return feed_dict

    def run(epoch, iter):
        x_r_input = sess.run(x_r)[:args.n_traversal]
        # keep per-layer codes as a list of arrays: np.array() on it fails (or
        # builds an object array) when the layers have different z_dims
        zs_ipt_fixed = list(sess.run(x_r_zs, feed_dict={x: x_r_input}))
        x_f_recon = sess.run(x_f, feed_dict=_feed(zs_ipt_fixed, x_r_input))

        L_opt = sess.run(tl.tensors_filter(G_test.func.variables, 'L'))
        for l in range(len(args.z_dims)):
            for j, i in enumerate(np.argsort(np.abs(L_opt[l]))[::-1]):
                x_f_opts = [x_r_input, x_f_recon]
                for v in vals:
                    zs_ipt = copy.deepcopy(zs_ipt_fixed)
                    zs_ipt[l][:, i] = v
                    x_f_opts.append(sess.run(x_f, feed_dict=_feed(zs_ipt, x_r_input)))

                sample = im.immerge(np.concatenate(x_f_opts, axis=2), n_rows=args.n_traversal)
                im.imwrite(sample, '%s/Epoch-%d_Iter-%d_Traversal-%d-%d-%.3f-%d.jpg' % (save_dir, epoch, iter, l, j, np.abs(L_opt[l][i]), i))

    return run


def clone_graph():
    """Return ``run()`` that copies the averaged decoder weights into ``G_test``.

    Trainable variables of G_test are taken from the ``G_ema`` shadow variables;
    non-trainable variables are copied directly from ``Gdec``.
    """
    ema_vars = tl.tensors_filter(tl.global_variables(), 'G_ema')
    copy_ops = [G_test.func.clone_from_vars(ema_vars, var_type='trainable'),
                G_test.func.clone_from_module(Gdec.func, var_type='nontrainable')]

    def run(**pl_ipts):
        sess.run(copy_ops)

    return run


# Build every graph before `tl.init`. The order is load-bearing: clone_graph looks up
# the 'G_ema' shadow variables, which only exist once G_train_graph has applied G_ema.
d_train_step = D_train_graph()
g_train_step = G_train_graph()
sample = sample_graph()
traversal = traversal_graph()
clone = clone_graph()


    
    
    

    
# ==============================================================================
# =                                   train                                    =
# ==============================================================================

# init
# NOTE(review): presumably restores the latest checkpoint (or initializes variables)
# and returns the global step counter plus its increment op — confirm in tflib.
checkpoint, step_cnt, update_cnt = tl.init(py.join(output_dir, 'checkpoints'), checkpoint_max_to_keep=1, session=sess)

# learning rate schedule (called with the epoch index in the training loop)
lr_fn = tl.LinearDecayLR(args.learning_rate, args.n_epochs, args.epoch_start_decay)

    

In [None]:

# train
# Defaults so the `finally` clause still works when no iteration runs
# (n_epochs == 0, a run resumed past its last iteration, or an early error).
ep, it = 0, 0
step = sess.run(step_cnt)
# each iteration draws n_d batches for D and one for G
n_iters_per_epoch = len_train_dataset // (args.n_d + 1)
try:
    for ep in trange(args.n_epochs, desc='Epoch Loop'):
        # learning rate
        lr_ipt = lr_fn(ep)

        for it in trange(n_iters_per_epoch, desc='Inner Epoch Loop'):
            # resume: skip iterations already covered by the restored counter
            if it + ep * n_iters_per_epoch < sess.run(step_cnt):
                continue
            step = sess.run(update_cnt)

            # train D
            d_train_step(lr=lr_ipt)
            # train G
            g_train_step(lr=lr_ipt)

            # save
            if step % args.checkpoint_save_period == 0:
                checkpoint.save(step, session=sess)

            # sample / traverse; refresh the EMA copy in G_test once for both
            do_sample = step % args.sample_period == 0
            do_traversal = step % args.traversal_period == 0
            if do_sample or do_traversal:
                clone()
            if do_sample:
                sample(ep, it)
            if do_traversal:
                traversal(ep, it)
except Exception:
    # best effort: report the error, then still sample and save below
    traceback.print_exc()
finally:
    clone()
    sample(ep, it)
    traversal(ep, it)
    checkpoint.save(step, session=sess)
    # `sess` is intentionally left open: the display cell below still runs ops in it.

    
    
    
    

In [None]:

#display sample


from IPython.display import display
from PIL import Image
from imlib import dtype



def display_sample():
    """Build ops that show one validation image, its reconstruction and a random sample.

    Returns ``run()``, which displays inline, in order: the real image; the image
    decoded from D's predicted latent codes plus encoder features; and the image
    decoded from random truncated-normal codes plus the same encoder features.
    """
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    zs = [tl.truncated_normal([args.n_samples, z_dim], minval=-args.truncation_threshold, maxval=args.truncation_threshold) for z_dim in args.z_dims]
    xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, args.n_channels])

    # generate
    x_r = val_iter.get_next()
    _, x_r_zs = D(xa, training=False)
    # size_splits (not a count), so unequal z_dims are split correctly
    x_r_zs = tf.split(x_r_zs, args.z_dims, axis=1)

    x_f_rand = G_test(zs, Genc(xa, training=False), training=False)
    x_f_recon = G_test(x_r_zs, Genc(xa, training=False), training=False)

    # ======================================
    # =            run function            =
    # ======================================

    def _show(img):
        img = dtype.im2uint(img)
        # PIL expects a 2-D array for single-channel images
        if img.shape[-1] == 1:
            img = img[..., 0]
        display(Image.fromarray(img))

    def run():
        # the random codes have a fixed batch of n_samples, so feed exactly that many images
        xa_ipt = sess.run(x_r)[:args.n_samples]
        x_f_opt_rand, x_f_opt_recon = sess.run([x_f_rand, x_f_recon], feed_dict={xa: xa_ipt})
        for img in (xa_ipt[0], x_f_opt_recon[0], x_f_opt_rand[0]):
            _show(img)

    return run


display_sample_func = display_sample()
display_sample_func()

