### Members:
- 109065511 張宜禎
- 109062562 蔡哲維
- 108065425 丘騏銘

# Import and setup

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras import layers
import os
import imageio
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path

import re
from IPython import display

In [None]:
# Pin TensorFlow to the first physical GPU and let it allocate device memory
# on demand instead of reserving all of it up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        first_gpu = gpus[0]
        tf.config.experimental.set_visible_devices(first_gpu, 'GPU')

        # Memory growth is configured identically on every detected GPU.
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)

        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Memory growth must be set before the GPUs have been initialized.
        print(err)

# Hyperparameters and Utilities

In [None]:
# Dataset sizes.
CAPTION_NUM = 70495                # (image, caption) pairs -- presumably the total caption count
DATASET_SIZE = 3 * CAPTION_NUM     # == 211485: the training set concatenates three views of each pair
BATCH_SIZE = 100
BUFFER_SIZE = 20000                # shuffle buffer size

# Layout of the sample grid rendered after every epoch.
SAMPLE_ROW = 8
SAMPLE_COL = 8
SAMPLE_NUM = SAMPLE_COL * SAMPLE_ROW

# Shape of every training / generated image.
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_CHANNEL = 3


In [None]:
# Model / training hyperparameters. Entries that restate module-level
# constants are computed from them so the two cannot drift apart.
hparas = {
    'EMBED_DIM': 1024,                          # caption embedding dimension
    'Z_DIM': 256,                               # random noise z dimension
    'DENSE_DIM': 128,                           # units of the text-compression Dense layers
    'IMAGE_SIZE': [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL],  # render image size
    # The input pipeline, the noise batch and epsilon in the train steps all use
    # the module-level BATCH_SIZE, so record that value here (not 256).
    'BATCH_SIZE': BATCH_SIZE,
    'LR': 1e-4,
    'LR_DECAY': 0.5,
    'BETA_1': 0.5,
    'N_EPOCH': 600,
    'N_SAMPLE': DATASET_SIZE // BATCH_SIZE,     # number of full batches per epoch
    'rs_Train': float(BATCH_SIZE) / float(DATASET_SIZE),  # ~1 / N_SAMPLE, averages per-epoch losses
    'CHECKPOINTS_DIR': './checkpoints/train',   # checkpoint path
    'PRINT_FREQ': 1,                            # printing frequency of loss
    'TEST_BATCH_SIZE': 91,
}
# Noise shapes use Z_DIM instead of repeating the literal 256.
hparas['BZ'] = (BATCH_SIZE, hparas['Z_DIM'])       # noise for one training batch
hparas['TEST_Z'] = (SAMPLE_NUM, hparas['Z_DIM'])   # noise for the sample grid

In [None]:
# Utility function
def utPuzzle(imgs, row, col, path=None):
    """Tile equally sized images into one ``row`` x ``col`` grid.

    Images are placed in row-major order. The canvas is ``uint8`` and has
    shape ``(h * row, w * col, c)``, where ``(h, w, c)`` is the shape of
    ``imgs[0]``. Unfilled cells stay black. When ``path`` is given, the grid
    is also written to disk with ``imageio.imwrite``.

    Returns the assembled grid array.
    """
    tile_h, tile_w, channels = imgs[0].shape
    grid = np.zeros((tile_h * row, tile_w * col, channels), np.uint8)
    for position, tile in enumerate(imgs):
        top = (position // col) * tile_h
        left = (position % col) * tile_w
        grid[top:top + tile_h, left:left + tile_w, :] = tile
    if path is not None:
        imageio.imwrite(path, grid)
    return grid
  
def utMakeGif(imgs, fname, duration):
    """Write ``imgs`` to ``fname`` as an animated GIF lasting ``duration`` seconds.

    Frames play at ``len(imgs) / duration`` frames per second; the frame shown
    at time ``t`` is ``imgs[int(n * t)]``.

    NOTE(review): ``mpy`` is not imported anywhere in the visible source
    (presumably ``import moviepy.editor as mpy`` was intended), so calling this
    as-is raises NameError. It is not called anywhere in this notebook.
    """
    n = float(len(imgs)) / duration  # frames per second
    clip = mpy.VideoClip(lambda t : imgs[int(n * t)], duration = duration)
    clip.write_gif(fname, fps = n)

In [None]:
# Vocabulary and word <-> id lookup tables produced by the preprocessing step.
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

# Each mapping file stores (key, value) pairs; dict() turns them into lookups.
word2Id_dict, id2word_dict = (
    dict(np.load(dictionary_path + suffix))
    for suffix in ('/word2Id.npy', '/id2Word.npy')
)

print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))

In [None]:
def sent2IdList(line, MAX_SEQ_LENGTH=20, word2id=None):
    """Convert a caption into a list of exactly ``MAX_SEQ_LENGTH`` word ids.

    Punctuation is replaced by spaces and the text is split on whitespace.
    Each token is looked up in the vocabulary, and tokens missing from it
    map to the ``<RARE>`` id. The sequence is truncated or padded with the
    ``<PAD>`` id so every result has the same length.

    Args:
        line: Caption text.
        MAX_SEQ_LENGTH: Length of the returned id list.
        word2id: Word -> id mapping. Defaults to the module-level
            ``word2Id_dict`` loaded from ``./dictionary``.

    Returns:
        A list of ``MAX_SEQ_LENGTH`` ids.
    """
    if word2id is None:
        word2id = word2Id_dict

    # string.punctuation already includes '-' and '.', so one substitution
    # removes every punctuation mark; split() then drops empty tokens.
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
    tokens = prep_line.split()

    # Truncate long captions and pad short ones so the length is always fixed.
    tokens = tokens[:MAX_SEQ_LENGTH]
    tokens += ['<PAD>'] * (MAX_SEQ_LENGTH - len(tokens))

    return [
        word2id[token] if token in word2id else word2id['<RARE>']
        for token in tokens
    ]

text = "the flower shown has yellow anther red pistil and bright red petals."
# Show the raw caption, then its padded id sequence on the next line.
print(text, sent2IdList(text), sep='\n')


In [None]:
def idx2word(indices_list, id2word=None, stop_id='5427', skip_id='5428'):
    """Decode lists of word ids back into caption strings.

    In each id sequence, ids equal to ``skip_id`` are dropped, decoding stops
    at the first ``stop_id``, and every other id is looked up in ``id2word``.
    Ids are compared as strings, matching the string keys of ``id2word_dict``.

    Args:
        indices_list: Iterable of id sequences, one per caption.
        id2word: Id -> word mapping. Defaults to the module-level
            ``id2word_dict``.
        stop_id: Id that ends a caption (default ``'5427'``).
        skip_id: Id that is ignored (default ``'5428'``).

    Returns:
        A list with one space-joined caption string per id sequence.
    """
    if id2word is None:
        id2word = id2word_dict

    captions = []
    for indices in indices_list:
        words = []
        for idx in indices:
            if idx == skip_id:
                continue
            if idx == stop_id:
                break
            words.append(id2word[idx])
        # join avoids quadratic '+' concatenation and does not shadow the
        # imported ``string`` module.
        captions.append(' '.join(words).strip())
    return captions

# Explore the data

In [None]:
# Load the caption / image-path table used throughout the notebook.
data_path = './dataset'
df = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = len(df)
n_images_train = num_training_sample  # one table row per training image
print('There are %d images in training data' % n_images_train)

In [None]:
# Peek at the first rows of the raw caption table.
df.head(5)

In [None]:
# Renumber rows 0..n-1 (old index discarded) so label lookups like df.loc[0, ...] hit the first row.
df = df.reset_index(drop=True)

# Preprocessing

In [None]:
# Decode each row's caption id sequences into plain-text captions.
df['texts'] = df['Captions'].apply(idx2word)

In [None]:
# Decoded captions of the first image.
df.at[0, 'texts']

In [None]:
def remove_empty_string(string_list):
    """Drop empty strings from a list of captions.

    When the list contains no empty string it is returned unchanged (the
    same object); otherwise a new list without the empty strings is returned.
    """
    if '' not in string_list:
        return string_list
    return [caption for caption in string_list if caption != '']

In [None]:
# Strip empty captions from every row.
df['texts'] = df['texts'].apply(remove_empty_string)

In [None]:
# Display the table after empty captions were removed.
df

#### Count the number of captions per image

In [None]:
def count_caption_num(string_list):
    """Return the number of captions in ``string_list`` (its ``len``)."""
    return len(string_list)

In [None]:
# Number of (non-empty) captions per image.
df['caption_num'] = df['texts'].apply(count_caption_num)

In [None]:
from collections import Counter

# Histogram: caption count -> number of images with that many captions.
# Counter replaces the hand-rolled tally; dict() keeps the plain-dict type
# and first-seen key order of the original loop.
num_dict = dict(Counter(df['caption_num'].tolist()))

In [None]:
# Show the caption-count histogram.
num_dict

# Get BERT Embedding

In [None]:
from transformers import BertTokenizer, TFBertModel

# Tokenizer and TF encoder for the pre-trained 'bert-large-uncased' checkpoint.
# The first-token vector of its last hidden state is used below as the
# caption embedding.
# NOTE(review): an *uncased* checkpoint is loaded with do_lower_case=False,
# so nothing here appears to lower-case the input. Confirm the captions are
# already lower-case; otherwise capitalised words may not match the vocabulary.
# do_basic_tokenize=False presumably skips the basic (punctuation-splitting)
# tokenizer, so text goes straight to WordPiece -- verify against the
# installed transformers version.
bert_tokenizer = BertTokenizer.from_pretrained(
    'bert-large-uncased', 
    do_lower_case=False,
    do_basic_tokenize=False
)
bert_model = TFBertModel.from_pretrained('bert-large-uncased')

In [None]:
def turn_to_bert_embedding(string_list):
    """Embed each caption with BERT and return its first-token vectors.

    Args:
        string_list: List of caption strings for one image.

    Returns:
        A nested Python list with one embedding (list of floats) per caption,
        taken from position 0 of the model's ``last_hidden_state``.

    Raises:
        ValueError: Re-raised from the tokenizer/model after printing the
            offending captions. The previous code swallowed it and then hit
            an UnboundLocalError on ``caption_embedding``.
    """
    try:
        # truncation=True: with padding='max_length' alone, captions longer
        # than max_length are neither padded nor truncated, and the batch
        # cannot be stacked into one tensor.
        bert_inputs = bert_tokenizer(
            string_list,
            return_tensors="tf",
            padding='max_length',
            truncation=True,
            max_length=30,
        )
        bert_outputs = bert_model(bert_inputs)
    except ValueError:
        print(string_list)
        raise
    caption_embedding = bert_outputs.last_hidden_state[:, 0]
    return caption_embedding.numpy().tolist()

# Smoke test: three captions in, expect three embeddings out.
test_string = [
    'this flower is white and pink in color with petals that have small veins',
    'the flower shown has a purple and white petal with white anther',
    'the four heart shaped pink petals of this flower are striped with fuchsia and their centers are yellow and white',
]
print(len(turn_to_bert_embedding(test_string)))

In [None]:
from datetime import datetime

# Embed every image's captions, logging wall-clock start and end times.
print("{}, start infering.".format(datetime.now()))
df['embeddings'] = df['texts'].apply(turn_to_bert_embedding)
print("{}, end infering.".format(datetime.now()))

In [None]:
# Dimension of one caption embedding (first caption of the first image).
len(df.at[0, 'embeddings'][0])

In [None]:
# Cache the table together with its BERT caption embeddings.
pd.to_pickle(df, "./dataset/text2img_cls_embedding.pkl")

# Write images into TFRecord files

In [None]:
# Reload the table with the cached caption embeddings (written by to_pickle above).
df = pd.read_pickle("./dataset/text2img_cls_embedding.pkl")

In [None]:
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte (or an eager string tensor)."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # BytesList won't unpack a string from an EagerTensor; take the raw bytes.
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double or a sequence of them.

    ``FloatList(value=...)`` requires an iterable, so a scalar is wrapped in
    a one-element list first. Sequences and arrays are passed through as-is.
    """
    if np.ndim(value) == 0:
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

In [None]:
def load_img(img_path):
    """Read the still-encoded bytes of the image file at ``img_path``.

    Decoding and resizing are deliberately deferred: the raw file contents
    are what get stored in the TFRecord, and ``processing`` decodes them
    after the records are read back.
    """
    return tf.io.read_file(img_path)

In [None]:
def serialize_example(img):
    """Wrap encoded image bytes in a tf.train.Example under key 'img' and serialize it."""
    features = tf.train.Features(feature={'img': _bytes_feature(img)})
    return tf.train.Example(features=features).SerializeToString()

In [None]:
def tf_serialize_example(img):
    """Graph-mode wrapper around ``serialize_example`` for ``Dataset.map``.

    ``serialize_example`` needs eager values, so it runs through
    ``tf.py_function``. The string result is then reshaped to a scalar.
    """
    serialized = tf.py_function(func=serialize_example, inp=[img], Tout=tf.string)
    return tf.reshape(serialized, [])

In [None]:
# One image path per table row, as a NumPy array.
image_paths = np.asarray(df['ImagePath'].values)

# Sanity check: the table is expected to describe exactly 7370 images.
assert len(image_paths) == 7370

In [None]:
# Write the encoded images into shards of ITEMS_PER_FILE records each,
# named train_000_.tfrecord, train_001_.tfrecord, ... in table order.
ITEMS_PER_FILE = 3000
num = 0
for i in range(0, len(image_paths), ITEMS_PER_FILE):
    shard_paths = image_paths[i:i + ITEMS_PER_FILE]
    shard = (
        tf.data.Dataset.from_tensor_slices(shard_paths)
        .map(load_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .map(tf_serialize_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    )
    writer = tf.data.experimental.TFRecordWriter(f'train_{num:03d}_.tfrecord')
    writer.write(shard)
    num += 1


# Read images back from the TFRecord files

In [None]:
# The three shards written above (train_000_ .. train_002_), read back in order.
filenames = [f'train_{shard:03d}_.tfrecord' for shard in range(3)]
raw_dataset_train = tf.data.TFRecordDataset(filenames)
raw_dataset_train

In [None]:
# Schema of each serialized Example: a single encoded-image byte string.
feature_description = {'img': tf.io.FixedLenFeature([], tf.string)}

def _parse_function(example_proto):
    """Extract the encoded image bytes from one serialized Example."""
    return tf.io.parse_single_example(example_proto, feature_description)['img']

In [None]:
# Turn each serialized record back into its encoded image bytes.
raw_dataset_train = raw_dataset_train.map(_parse_function)
raw_dataset_train

In [None]:
def processing(img):
    """Decode JPEG bytes into an IMAGE_HEIGHT x IMAGE_WIDTH RGB float image in [0, 1]."""
    decoded = tf.image.decode_jpeg(img, channels=3)
    resized = tf.image.resize(decoded, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    return tf.cast(resized, tf.float32) / 255.

In [None]:
# Decode, resize and scale every image to [0, 1] floats.
raw_dataset_train = raw_dataset_train.map(processing)

In [None]:
# Materialise every decoded training image as a NumPy array, in TFRecord
# order (which follows the row order of ``df``).
imgs = [img.numpy() for img in raw_dataset_train]

In [None]:
# Expected to equal the number of table rows (7370).
print(len(imgs))

In [None]:
# Preview a handful of decoded images. Drawing one figure per training image
# (thousands of plt.show() calls) floods the notebook output and memory, so
# only the first N_PREVIEW images are rendered.
N_PREVIEW = 16
for img in imgs[:N_PREVIEW]:
    plt.imshow(img)
    plt.axis("off")
    plt.show()

# Create the tf.data training dataset

In [None]:
def training_data_generator(img, embedding):
    """Map an (image, embedding) pair to the model's input format.

    The image gets a static IMAGE_HEIGHT x IMAGE_WIDTH x IMAGE_CHANNEL shape
    and is rescaled with ``2 * x - 1`` (from [0, 1] to [-1, 1], assuming
    [0, 1] input). The embedding is cast to float32.
    """
    image = tf.image.convert_image_dtype(img, tf.float32)
    image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    return 2. * image - 1., tf.cast(embedding, tf.float32)

def flip_right_left_data_generator(img, embedding):
    """Same preprocessing as ``training_data_generator``, plus a horizontal mirror."""
    image, embedding = training_data_generator(img, embedding)
    return tf.image.flip_left_right(image), embedding


def adjust_brightness_data_generator(img, embedding):
    """Augmentation: random brightness jitter plus a vertical flip.

    A random brightness delta of magnitude at most 0.2 is added to the image,
    which is then clipped back to [0, 1] (assuming [0, 1] input) before the
    ``2 * x - 1`` rescale. Without the clip, real images could leave the
    [-1, 1] tanh range the generator produces. Despite the name, the image is
    also flipped upside down. The embedding is cast to float32.
    """
    img = tf.image.convert_image_dtype(img, tf.float32)
    # random_brightness(image, max_delta, seed): the positional ``2`` used
    # previously was always the op seed, not an upper bound; name both.
    img = tf.image.random_brightness(img, max_delta=0.2, seed=2)
    img = tf.clip_by_value(img, 0., 1.)
    img.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    img = img * 2. - 1.
    img = tf.image.flip_up_down(img)
    embedding = tf.cast(embedding, tf.float32)

    return img, embedding

In [None]:
df = pd.read_pickle("./dataset/text2img_cls_embedding.pkl")

embeddings = df['embeddings'].values

# imgs were read back from the TFRecords in table-row order, so imgs[i]
# belongs with embeddings[i]. Fail loudly if that alignment is broken rather
# than silently pairing images with the wrong captions.
if len(imgs) != len(embeddings):
    raise ValueError(
        'got %d images but %d caption-embedding rows' % (len(imgs), len(embeddings))
    )

# One training pair per caption: each image is repeated once per caption
# embedding it has, in the same order as the flattened embeddings.
caption_counts = [len(image_embeddings) for image_embeddings in embeddings]
embedding = np.asarray(
    [emb for image_embeddings in embeddings for emb in image_embeddings],
    # training_data_generator casts to float32 anyway; storing float32 here
    # halves this large array's memory compared to the float64 default.
    dtype=np.float32,
)
img_for_dataset = np.repeat(np.stack(imgs), caption_counts, axis=0)

assert embedding.shape[0] == img_for_dataset.shape[0]

In [None]:
# Build the (image, embedding) source once and derive all three views from it.
# Each from_tensor_slices call converts the multi-gigabyte NumPy arrays into
# its own tensors, so one shared source keeps a single in-memory copy.
base_dataset = tf.data.Dataset.from_tensor_slices((img_for_dataset, embedding))
dataset = base_dataset.map(training_data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
flip_right_left_dataset = base_dataset.map(flip_right_left_data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
adjust_brightness_dataset = base_dataset.map(adjust_brightness_data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Resulting order: all plain pairs, then all horizontal flips, then all
# brightness / vertical-flip views.
dataset = dataset.concatenate(flip_right_left_dataset)
dataset = dataset.concatenate(adjust_brightness_dataset)

In [None]:
# Number of training pairs across all three views. Read the statically known
# cardinality instead of iterating (decoding and augmenting) every element.
# Fall back to counting only if TensorFlow cannot infer it (negative sentinel).
num = int(tf.data.experimental.cardinality(dataset))
if num < 0:
    num = sum(1 for _ in dataset)
num

In [None]:
# Shuffle within a BUFFER_SIZE window, group into fixed-size batches (the
# ragged last batch is dropped, so every step sees BATCH_SIZE examples), and
# prefetch so input preparation overlaps training.
# NOTE(review): the source is ordered plain -> flipped -> brightness views,
# so a BUFFER_SIZE window only mixes neighbouring elements of that order.
dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

In [None]:
# Inspect the element spec of the batched pipeline.
dataset

# Modified DCGAN with WGAN-GP Loss

In [None]:
class Generator(tf.keras.Model):
    """
    Generate a fake image from a caption embedding and a noise vector z.

    The embedding is compressed by a Dense(DENSE_DIM) layer, concatenated with
    z, projected to a 4x4x1024 feature map, and upsampled by four stride-2
    transposed convolutions (4 -> 8 -> 16 -> 32 -> 64).

    input: noise z and text embedding
    output: fake image of size 64*64*3 with values in [-1, 1] (tanh)

    NOTE: the layer attribute names (dc1, bn1, ...) are the keys under which
    tf.train.Checkpoint stores these weights; renaming them breaks restoring
    existing checkpoints.
    """
    def __init__(self, hparas):
        super(Generator, self).__init__()
        self.hparas = hparas
        self.compress = tf.keras.layers.Dense(self.hparas['DENSE_DIM'])
        self.to_4_4_1024 = tf.keras.layers.Dense(4*4*1024)
        
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.bn3 = tf.keras.layers.BatchNormalization()
        self.bn4 = tf.keras.layers.BatchNormalization()
        
        self.lr1 = tf.keras.layers.LeakyReLU()
        self.lr2 = tf.keras.layers.LeakyReLU()
        self.lr3 = tf.keras.layers.LeakyReLU()
        self.lr4 = tf.keras.layers.LeakyReLU()
        self.lr5 = tf.keras.layers.LeakyReLU()
        
        # Each stride-2 "SAME" transposed convolution doubles height and width.
        self.dc1 = tf.keras.layers.Conv2DTranspose(
            filters = 512,
            kernel_size = 5,
            strides = 2,
            padding = "SAME"
        )
        self.dc2 = tf.keras.layers.Conv2DTranspose(
            filters = 256,
            kernel_size = 5,
            strides = 2,
            padding = "SAME"
        )
        self.dc3 = tf.keras.layers.Conv2DTranspose(
            filters = 128,
            kernel_size = 5,
            strides = 2,
            padding = "SAME"
        )
        self.dc4 = tf.keras.layers.Conv2DTranspose(
            filters = 3,
            kernel_size = 5,
            strides = 2,
            padding = "SAME"
        )
        
        
    def call(self, noise_z, text, training):
        """
        noise_z: noise batch, presumably (batch, Z_DIM) -- see hparas['BZ'].
        text: caption embeddings, presumably (batch, EMBED_DIM) -- confirm.
        training: forwarded to the BatchNormalization layers.
        Returns a (batch, 64, 64, 3) image tensor in [-1, 1].
        """
        # compress the embedding
        text = self.compress(text)
        text = self.lr1(text)
        
        # concatenate input text and random noise
        text_concat = tf.concat([noise_z, text], axis=1)
        
        # To 4*4*1024
        text_concat = self.to_4_4_1024(text_concat)
        text_concat = tf.reshape(text_concat, [-1, 4, 4, 1024])
        text_concat = self.bn1(text_concat,training=training)
        text_concat = self.lr2(text_concat)
        
        # To 8*8*512
        text_concat = self.dc1(text_concat)
        text_concat = self.bn2(text_concat,training=training)
        text_concat = self.lr3(text_concat)
        
        # To 16*16*256
        text_concat = self.dc2(text_concat)
        text_concat = self.bn3(text_concat,training=training)
        text_concat = self.lr4(text_concat)
        
        # To 32*32*128
        text_concat = self.dc3(text_concat)
        text_concat = self.bn4(text_concat,training=training)
        text_concat = self.lr5(text_concat)
        
        # To 64*64*3
        text_concat = self.dc4(text_concat)
        
        # tanh bounds the output to [-1, 1], the same range as the rescaled real images.
        output = tf.nn.tanh(text_concat)
        
        return output

In [None]:
class Discriminator(tf.keras.Model):
    """
    WGAN-GP critic: score an image conditioned on its caption embedding.

    input: image and corresponding text embedding
    output: one unbounded real-valued score per example from a Dense(1)
        layer (no sigmoid). The critic loss in DC_D_Train,
        mean(D(fake) - D(real)) + gradient penalty, pushes real-image scores
        above fake-image scores; the scores are not 0/1 probabilities.

    NOTE: the layer attribute names are the keys under which
    tf.train.Checkpoint stores these weights; renaming them breaks restoring
    existing checkpoints.
    """
    def __init__(self, hparas):
        super(Discriminator, self).__init__()
        self.hparas = hparas
        self.compress = tf.keras.layers.Dense(self.hparas['DENSE_DIM'])
        self.d = tf.keras.layers.Dense(1)
        
        self.relu = tf.keras.layers.ReLU()
        self.relu2 = tf.keras.layers.ReLU()
        
        # NOTE(review): bn1-bn5 are never applied -- their calls in call() are
        # commented out (batch norm in a WGAN-GP critic interferes with the
        # per-example gradient penalty). They are kept only as declared layers.
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.bn3 = tf.keras.layers.BatchNormalization()
        self.bn4 = tf.keras.layers.BatchNormalization()
        self.bn5 = tf.keras.layers.BatchNormalization()
        
        # NOTE(review): lr5 is never used in call().
        self.lr1 = tf.keras.layers.LeakyReLU()
        self.lr2 = tf.keras.layers.LeakyReLU()
        self.lr3 = tf.keras.layers.LeakyReLU()
        self.lr4 = tf.keras.layers.LeakyReLU()
        self.lr5 = tf.keras.layers.LeakyReLU()
        
        
        # Each stride-2 "SAME" convolution halves height and width (64 -> 4).
        self.conv1 = tf.keras.layers.Conv2D(
            filters = 128,
            kernel_size = 5,
            strides = (2, 2),
            padding = "SAME",
            input_shape = (64,64,3))
        
        self.conv2 = tf.keras.layers.Conv2D(
            filters = 256,
            kernel_size = 5,
            strides = (2, 2),
            padding = "SAME")
        
        self.conv3 = tf.keras.layers.Conv2D(
            filters = 512,
            kernel_size = 5,
            strides = (2, 2),
            padding = "SAME")
        
        self.conv4 = tf.keras.layers.Conv2D(
            filters = 1024,
            kernel_size = 5,
            strides = (2, 2),
            padding = "SAME")
        
        # 1x1 convolution that fuses image features with the tiled text features.
        self.conv5 = tf.keras.layers.Conv2D(
            filters = 1024,
            kernel_size = 1,
            strides = (1, 1),
            padding = "SAME")
    
    def call(self, img, text, training):
        """
        img: image batch, presumably (batch, 64, 64, 3) in [-1, 1] -- confirm.
        text: caption embeddings, presumably (batch, EMBED_DIM).
        training: accepted for API symmetry but unused, because every
            batch-norm call below is commented out.
        Returns a (batch, 1) score tensor.
        """
        # Compress embedding
        text = self.compress(text)
        text = self.relu(text)
        # To 32*32*128
        img = self.conv1(img)
        #img = self.bn1(img,training=training)
        img = self.lr1(img)
        # To 16*16*256
        img = self.conv2(img)
        #img = self.bn2(img,training=training)
        img = self.lr2(img)
        # To 8*8*512
        img = self.conv3(img)
        #img = self.bn3(img,training=training)
        img = self.lr3(img)
        # To 4*4*1024
        img = self.conv4(img)
        #img = self.bn4(img,training=training)
        img = self.lr4(img)
        
        # concatenate image with paired text: broadcast the compressed text
        # vector to every one of the 4x4 spatial positions, then stack on channels.
        text = tf.expand_dims(text,axis=1)
        text = tf.expand_dims(text,axis=1)
        text = tf.tile(text,multiples=[1,4,4,1])
        img_text = tf.concat([img, text], axis=-1)
        
        img_text = self.conv5(img_text)
        #img_text = self.bn5(img_text,training=training)
        img_text = self.relu2(img_text)
        
        img_text = tf.reshape(img_text, [-1, 4*4*1024])
        
        
        score = self.d(img_text)
        return score

In [None]:
# Instantiate both networks from the shared hyperparameters.
generator, discriminator = Generator(hparas), Discriminator(hparas)

In [None]:
# Adam for both networks using the learning rate and beta_1 declared in
# hparas. Without passing BETA_1, Adam falls back to its own default beta_1
# and the 0.5 configured in hparas has no effect.
generator_optimizer = tf.keras.optimizers.Adam(hparas['LR'], beta_1=hparas['BETA_1'])
discriminator_optimizer = tf.keras.optimizers.Adam(hparas['LR'], beta_1=hparas['BETA_1'])

In [None]:
checkpoint_path = hparas['CHECKPOINTS_DIR']
# Track both networks and their optimizers so training can be resumed.
ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
)
# Keep only the 10 most recent checkpoints under checkpoint_path.
ckpt_manager = tf.train.CheckpointManager(ckpt, directory=checkpoint_path, max_to_keep=10)

In [None]:
@tf.function
def DC_D_Train(c1,embed,noise_decay):
    """One WGAN-GP critic (discriminator) update.

    Args:
        c1: batch of real images; must have BATCH_SIZE rows to match epsilon.
        embed: caption embeddings paired with ``c1``.
        noise_decay: scale of the Gaussian noise added to the fake, real and
            interpolated images before the critic scores them.

    Returns:
        (lg, ld): generator loss -mean(D(fake)) and critic loss
        mean(D(fake) - D(real) + 10 * (||dD/dx_hat|| - 1)^2).
        Only the discriminator's trainable variables receive gradients.
    """
    z = tf.random.normal(hparas['BZ']) 

    with tf.GradientTape() as tp:
        with tf.GradientTape() as tp_2:
            x_bar = generator(z, embed, training = True)
            # Random per-example interpolation point between real and fake images.
            epsilon = tf.random.uniform([BATCH_SIZE,1,1,1])
            x = c1
            x_hat = epsilon * x + (1. - epsilon) * x_bar

            x_bar = x_bar + noise_decay * tf.random.normal(x_bar.shape)
            x = x + noise_decay * tf.random.normal(x.shape)
            x_hat = x_hat + noise_decay * tf.random.normal(x_hat.shape)

            z0 = discriminator(x_bar, embed, training = True)
            z1 = discriminator(x, embed, training = True)
            z2 = discriminator(x_hat, embed, training = True)

            # Taken while the outer tape is still recording, so ld's gradient
            # differentiates through the penalty (second-order gradient).
            gradient_penalty = tp_2.gradient(z2,x_hat)
            # Per-example L2 norm of the critic's gradient w.r.t. its input.
            gradient_penalty = tf.sqrt(tf.reduce_sum(tf.math.square(gradient_penalty),axis=[1,2,3]))
            # NOTE(review): z0 - z1 has shape (batch, 1) and the penalty (batch,),
            # so this broadcasts to (batch, batch). The reduce_mean still equals
            # mean(z0 - z1) + mean(penalty), but squeezing z0/z1 would be clearer.
            loss = z0 - z1 + 10. * tf.math.square((gradient_penalty - 1.))
            ld = tf.reduce_mean(loss)
            lg = - tf.reduce_mean(z0)

    gradient_d = tp.gradient(ld, discriminator.trainable_variables)

    discriminator_optimizer.apply_gradients(zip(gradient_d, discriminator.trainable_variables))

    return lg, ld

@tf.function
def DC_G_Train(c1,embed,noise_decay):
    """One generator update.

    Runs the same forward pass as DC_D_Train, including the gradient penalty,
    so that ``ld`` can be returned for logging. Only the gradient of
    lg = -mean(D(fake)) is applied, and only to the generator's trainable
    variables.

    NOTE(review): the nested tape, the real-image score and the penalty are not
    needed for the generator gradient; they cost extra compute and memory here.

    Returns:
        (lg, ld) computed exactly as in DC_D_Train.
    """
    
    z = tf.random.normal(hparas['BZ'])

    with tf.GradientTape() as tp:
        with tf.GradientTape() as tp_2:
            x_bar = generator(z, embed, training = True)
            epsilon = tf.random.uniform([BATCH_SIZE,1,1,1])
            x = c1
            x_hat = epsilon * x + (1. - epsilon) * x_bar

            x_bar = x_bar + noise_decay * tf.random.normal(x_bar.shape)
            x = x + noise_decay * tf.random.normal(x.shape)
            x_hat = x_hat + noise_decay * tf.random.normal(x_hat.shape)

            z0 = discriminator(x_bar, embed, training = True)
            z1 = discriminator(x, embed, training = True)
            z2 = discriminator(x_hat, embed, training = True)
            gradient_penalty = tp_2.gradient(z2,x_hat)
            gradient_penalty = tf.sqrt(tf.reduce_sum(tf.math.square(gradient_penalty),axis=[1,2,3]))
            loss = z0 - z1 + 10. * tf.math.square((gradient_penalty - 1.))
            ld = tf.reduce_mean(loss)
            # Generator objective: raise the critic's score on generated images.
            lg = - tf.reduce_mean(z0)

    gradient_g = tp.gradient(lg, generator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradient_g, generator.trainable_variables))

    return lg, ld

In [None]:
@tf.function
def test_step(noise, embed):
    """Render images from ``noise`` and caption embeddings ``embed``.

    The generator runs with ``training=False``, so its batch-normalization
    layers use their moving statistics rather than batch statistics.
    """
    return generator(noise, embed, training=False)

In [None]:
df = pd.read_pickle("./dataset/text2img_cls_embedding.pkl")

embeddings = df['embeddings'].values

# Fixed conditioning batch for the per-epoch sample grid: the first
# SAMPLE_ROW captions of each of the first SAMPLE_COL images that have at least
# that many captions. utPuzzle is later called with SAMPLE_COL rows and
# SAMPLE_ROW columns, so each grid row shows one image's captions.
# Rows are scanned until enough qualifying images are found, so the batch
# always holds exactly SAMPLE_NUM embeddings to pair with test_noise.
caps_per_image = SAMPLE_ROW
test_embed = []
for image_embeddings in embeddings:
    if len(test_embed) == SAMPLE_NUM:
        break
    if len(image_embeddings) >= caps_per_image:
        test_embed.extend(image_embeddings[:caps_per_image])
if len(test_embed) != SAMPLE_NUM:
    raise ValueError(
        'not enough images with %d captions to fill the %d-sample grid'
        % (caps_per_image, SAMPLE_NUM)
    )
test_noise = tf.random.normal(hparas['TEST_Z'])
test_embed = tf.Variable(test_embed)

In [None]:
# Update schedule cycled over batches: five critic (discriminator) steps
# followed by one generator step.
Train = (DC_D_Train,) * 5 + (DC_G_Train,)

# Length of one schedule cycle (6); the step counter wraps at this value.
Critic = len(Train)

In [None]:
wlg = [None] * hparas['N_EPOCH']  # record loss of g for each epoch
wld = [None] * hparas['N_EPOCH']  # record loss of d for each epoch
wsp = [None] * hparas['N_EPOCH']  # record sample images for each epoch

# utPuzzle writes one sample grid per epoch into imgs_v2/; create the
# directory up front so imageio.imwrite does not fail on the first epoch.
os.makedirs('imgs_v2', exist_ok=True)

rsTrain = hparas['rs_Train']
ctr = 0
for ep in range(hparas['N_EPOCH']):
    print("Epoch: " + str(ep+1), end='\r')
    print('')
    lgt = 0.0
    ldt = 0.0
    # Scale of the Gaussian noise the train steps add to the critic's inputs:
    # annealed as 1 / (epoch + 1) for the first 200 epochs, then disabled.
    if ep < 200:
        noise_decay = 1.0 / float(ep+1)
    else:
        noise_decay = 0.0
    # The train steps are tf.functions. A different Python float every epoch
    # would force a fresh trace each time, whereas a float32 scalar tensor
    # reuses one trace.
    noise_decay_t = tf.constant(noise_decay, dtype=tf.float32)

    for idx, (real_img, embed) in enumerate(dataset):
        print(str(idx+1) + '/' + str(hparas['N_SAMPLE']), end='\r')
        # Cycle through Train (5 critic steps, then 1 generator step). ctr
        # persists across epochs, so the cycle continues where it left off.
        lg, ld = Train[ctr](real_img, embed, noise_decay_t)
        ctr = (ctr + 1) % Critic
        lgt += lg.numpy()
        ldt += ld.numpy()
    print('')
    # Scale the summed per-step losses by BATCH_SIZE / DATASET_SIZE
    # (about 1 / steps per epoch) to get per-epoch averages.
    wlg[ep] = lgt * rsTrain
    wld[ep] = ldt * rsTrain
    with open('./wlg_v2.txt', 'a') as f:
        f.write(str(wlg[ep]) + '\n')
    with open('./wld_v2.txt', 'a') as f:
        f.write(str(wld[ep]) + '\n')

    # Render the fixed sample grid, mapping the tanh range [-1, 1] to [0, 255].
    out = test_step(test_noise, test_embed)
    img = utPuzzle(
        ((out+1) / 2. * 255.0).numpy().astype(np.uint8),
        SAMPLE_COL,
        SAMPLE_ROW,
        "imgs_v2/w_%04d.png" % ep
    )
    wsp[ep] = img
    # Every 10 epochs: show the current grid and save a checkpoint.
    if (ep+1) % 10 == 0:
        plt.imshow(img)
        plt.axis("off")
        plt.title("Epoch %d" % (ep+1))
        plt.show()
        ckpt_manager.save()