In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import tensorflow.keras as keras
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path

import re
from IPython import display

from transformers import TFAutoModel, AutoTokenizer
from tqdm import tqdm

In [None]:
# Expose only the first GPU to TensorFlow and let it grow memory on demand
# instead of reserving the whole card up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        # Memory growth has to be configured identically for every GPU.
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Raised when the GPUs were already initialized before this cell ran.
        print(err)


In [None]:
dictionary_path = './dictionary'

# Full vocabulary array; only its size is reported here.
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

# Each file holds (key, value) pairs that dict() turns into a lookup table.
# id2word_dict is keyed by the id as a string (it is queried with '1' below).
word2Id_dict, id2word_dict = [
    dict(np.load(dictionary_path + suffix))
    for suffix in ('/word2Id.npy', '/id2Word.npy')
]

example_word, example_id = 'flower', '1'
print('Word to id mapping, for example: %s -> %s' % (example_word, word2Id_dict[example_word]))
print('Id to word mapping, for example: %s -> %s' % (example_id, id2word_dict[example_id]))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))


In [None]:
def sent2IdList(line, MAX_SEQ_LENGTH=20, word2id=None):
    """Convert a caption string into a list of exactly MAX_SEQ_LENGTH vocabulary ids.

    Every punctuation character (``string.punctuation`` already includes '-'
    and '.') is replaced by a space and the text is split on whitespace. The
    token list is then truncated or right-padded with '<PAD>' to
    MAX_SEQ_LENGTH tokens. Tokens missing from the vocabulary map to the
    '<RARE>' id.

    Args:
        line: raw caption text.
        MAX_SEQ_LENGTH: length of the returned id list.
        word2id: word -> id mapping; defaults to the module-level ``word2Id_dict``.

    Returns:
        list of MAX_SEQ_LENGTH ids, of whatever type the mapping stores.
    """
    if word2id is None:
        word2id = word2Id_dict

    # data preprocessing: turn all punctuation into whitespace
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line)
    # split() drops empty tokens produced by runs of whitespace; truncate long
    # captions so the output length is always MAX_SEQ_LENGTH (previously long
    # captions came back longer than that).
    tokens = prep_line.split()[:MAX_SEQ_LENGTH]
    tokens += ['<PAD>'] * (MAX_SEQ_LENGTH - len(tokens))

    return [
        word2id[token] if token in word2id else word2id['<RARE>']
        for token in tokens
    ]

# Quick sanity check of the caption -> id conversion.
text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
example_ids = sent2IdList(text)
print(example_ids)


In [None]:
def sent2WordList(caption):
    """Decode a sequence of vocabulary ids back into a space-separated sentence.

    Ids not present in ``id2word_dict`` are skipped. Words whose first
    character is '<' (special tokens such as '<PAD>' and '<RARE>') are skipped too.
    """
    words = []
    for token_id in caption:
        key = str(token_id)
        if key not in id2word_dict:
            continue
        word = id2word_dict[key]
        if word[0] == '<':
            continue
        words.append(word)
    return ' '.join(words)

In [None]:
decoded_text = sent2WordList(sent2IdList(text))  # round trip: text -> ids -> text
print(decoded_text)

# Text Encoder

In [None]:
class TFSentenceTransformer(keras.layers.Layer):
    """Keras layer that turns tokenized text into sentence embeddings.

    Wraps a Hugging Face TF encoder. The first element of its output
    (presumably the per-token hidden states) is mean-pooled over positions
    where ``attention_mask`` is non-zero, then optionally L2-normalized per row.
    """

    def __init__(self, model_name):
        super().__init__()
        self.model = TFAutoModel.from_pretrained(model_name)

    def call(self, inputs, normalize=True):
        """Embed ``inputs``, a tokenizer output mapping that contains 'attention_mask'."""
        pooled = self.mean_pooling(self.model(inputs), inputs['attention_mask'])
        return self.normalize(pooled) if normalize else pooled

    def mean_pooling(self, model_output, attention_mask):
        """Average ``model_output[0]`` over axis 1, counting only unmasked positions."""
        token_embeddings = model_output[0]
        mask = tf.expand_dims(attention_mask, -1)
        mask = tf.cast(tf.broadcast_to(mask, tf.shape(token_embeddings)), tf.float32)
        summed = tf.math.reduce_sum(token_embeddings * mask, axis=1)
        # Keep the divisor strictly positive so an all-padding row cannot divide by zero.
        counts = tf.clip_by_value(tf.math.reduce_sum(mask, axis=1), 1e-9, tf.float32.max)
        return summed / counts

    def normalize(self, embeddings):
        """Scale every row of ``embeddings`` to unit L2 norm."""
        unit_rows, _norms = tf.linalg.normalize(embeddings, 2, axis=1)
        return unit_rows


class E2ESentenceTransformer(keras.Model):
    """End-to-end text encoder: raw strings in, sentence embeddings out.

    ``call`` tokenizes with the Hugging Face tokenizer (``padding=True``,
    ``truncation=True``, TF tensors). It then pools the encoder output with
    ``TFSentenceTransformer``.
    """

    def __init__(self, model_name):
        super().__init__(name='text_encoder')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFSentenceTransformer(model_name)

    def call(self, inputs):
        """Tokenize ``inputs`` (presumably a list of str) and return their embeddings."""
        encoded = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            return_tensors='tf',
        )
        return self.model(encoded)


# Generator

In [None]:
class Generator(keras.Model):
    """Text-conditioned image generator with a DCGAN-style upsampling stack.

    Steps:
    - The sentence embedding is flattened and projected to 384 features.
    - The noise vector is placed in front of it, and Dense(8192) plus a
      reshape produce a 4x4x512 tensor.
    - One stride-1 and four stride-2 4x4 transposed convolutions upsample
      this to 64x64x3 with a tanh output, i.e. values in [-1, 1].

    Args:
        input_z_shape: noise dimension (only used by ``summary``).
        emb_shape: sentence-embedding dimension (only used by ``summary``).
    """

    def __init__(self, input_z_shape, emb_shape):
        super().__init__(name='generator')
        self.input_z_shape = input_z_shape
        self.emb_shape = emb_shape

        def up_block(filters, stride, **extra):
            # 4x4 kernel, 'same' padding, no bias, N(0, 0.02) init (new initializer per layer).
            return keras.layers.Conv2DTranspose(
                filters,
                (4, 4),
                strides=(stride, stride),
                padding='same',
                use_bias=False,
                kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
                **extra
            )

        self.text = keras.Sequential([
            keras.layers.Flatten(),
            keras.layers.Dense(384),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
        ])

        # 8192 = 4 * 4 * 512
        stack = [
            keras.layers.Dense(8192, use_bias=False),
            keras.layers.Reshape((4, 4, 512)),
        ]
        # 4x4 -> 4x4 -> 8x8 -> 16x16 -> 32x32, each followed by BN + LeakyReLU.
        for filters, stride in ((512, 1), (256, 2), (128, 2), (64, 2)):
            stack.extend([
                up_block(filters, stride),
                keras.layers.BatchNormalization(),
                keras.layers.LeakyReLU(),
            ])
        # Final 32x32 -> 64x64 RGB layer.
        stack.append(up_block(3, 2, activation='tanh'))
        self.generator = keras.Sequential(stack)

    def call(self, text, noise_z):
        """Generate a batch of images from embeddings ``text`` and noise ``noise_z``."""
        projected = self.text(text)
        # Noise first, then text: the Dense(8192) weights depend on this order.
        return self.generator(tf.concat([noise_z, projected], axis=1))

    def summary(self):
        """Print a layer summary by tracing ``call`` on symbolic Keras inputs."""
        text_in = keras.layers.Input(shape=(self.emb_shape, ), name='text')
        noise_in = keras.layers.Input(shape=(self.input_z_shape, ), name='noise_z')
        traced = keras.Model(
            name='generator',
            inputs=[text_in, noise_in],
            outputs=self.call(text_in, noise_in),
        )
        return traced.summary()


# Discriminator

In [None]:
class Discriminator(keras.Model):
    """Text-conditioned critic for the WGAN-GP training in this notebook.

    The two branches are merged along the channel axis:
    - The image passes through four stride-2 4x4 convolutions
      (e.g. 64x64x3 -> 4x4x512).
    - The sentence embedding is projected to 384 features and reshaped to
      4x4x24.
    A 1x1 conv, a 4x4 conv to one channel and a Dense(1) then give one
    unbounded score per example (no sigmoid).

    NOTE(review): the text branch uses BatchNormalization. The WGAN-GP paper
    advises against batch norm in the critic because it couples the examples
    of a batch. It is left unchanged to stay compatible with saved checkpoints.

    Args:
        input_x_shape: image shape, e.g. [64, 64, 3] (only used by ``summary``).
        emb_shape: sentence-embedding dimension (only used by ``summary``).
    """

    def __init__(self, input_x_shape, emb_shape):
        # Previously name='generator' (copy-paste slip), so both networks shared
        # a model name. tf.train.Checkpoint matches objects by attribute path
        # rather than by model name, so existing checkpoints should still restore.
        super().__init__(name='discriminator')
        self.input_x_shape = input_x_shape
        self.emb_shape = emb_shape

        def conv(filters, kernel_size, stride):
            # Bias-free, 'same' padding, N(0, 0.02) init (new initializer per layer).
            return keras.layers.Conv2D(
                filters,
                kernel_size,
                strides=(stride, stride),
                padding='same',
                use_bias=False,
                kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
            )

        self.text = keras.Sequential([
            keras.layers.Flatten(),
            keras.layers.Dense(384),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
        ])
        self.image = keras.Sequential([
            conv(64, (4, 4), 2), keras.layers.LeakyReLU(),
            conv(128, (4, 4), 2), keras.layers.LeakyReLU(),
            conv(256, (4, 4), 2), keras.layers.LeakyReLU(),
            conv(512, (4, 4), 2), keras.layers.LeakyReLU(),
        ])
        self.discriminator = keras.Sequential([
            conv(512, (1, 1), 1), keras.layers.LeakyReLU(),
            conv(1, (4, 4), 1),
            keras.layers.Flatten(),
            keras.layers.Dense(1),
        ])

    def call(self, text, img):
        """Score (text embedding, image) pairs; returns shape (batch, 1)."""
        text = self.text(text)
        # 384 projected features -> 4x4 spatial map with 384 // 16 = 24 channels,
        # matching the 4x4 spatial size of the image features.
        text = tf.reshape(text, shape=(-1, 4, 4, text.shape[-1] // 16))
        img = self.image(img)
        text_concat = tf.concat([img, text], axis=3)
        return self.discriminator(text_concat)

    def summary(self):
        """Print a layer summary by tracing ``call`` on symbolic Keras inputs."""
        text = keras.layers.Input(shape=(self.emb_shape, ), name='text')
        img = keras.layers.Input(shape=self.input_x_shape, name='img')
        model = keras.Model(name='discriminator', inputs=[text, img], outputs=self.call(text, img))
        return model.summary()


In [None]:
# Pretrained sentence encoder used to embed every caption.
MODEL_NAME = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
# Embedding width; only used for the summary() input shapes. Presumably the
# encoder's output size (768 for this mpnet-base model) - keep them in sync.
EMBED_DIM = 768
Z_DIM = 128                # generator noise dimension
IMAGE_SIZE = [64, 64, 3]   # (height, width, channels) of real/generated images

LEARNING_RATE = 0.0001     # shared by both Adam optimizers

In [None]:
# Frozen text encoder plus the two trainable GAN networks.
text_encoder = E2ESentenceTransformer(model_name=MODEL_NAME)
generator = Generator(input_z_shape=Z_DIM, emb_shape=EMBED_DIM)
discriminator = Discriminator(input_x_shape=IMAGE_SIZE, emb_shape=EMBED_DIM)

# Dataset

In [None]:
data_path = './dataset'
# Training table (captions + image paths). read_pickle can execute code from
# the file, so only load trusted pickles.
df = pd.read_pickle(f'{data_path}/text2ImgData.pkl')
n_images_train = num_training_sample = len(df)
print('There are %d image in training data' % (n_images_train))


In [None]:

df.head(n=5)


In [None]:

# The competition requires generated images of size 64x64x3.
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL = 64, 64, 3
# NOTE(review): despite the name, dataset_generator uses this as the number of
# times the whole (caption, image) pair list is repeated per epoch.
NUM_CAP_PER_IMG = 3

def training_data_generator_aug(caption, image_path):
    """Load one training image, augment it, and pair it with its caption.

    Pipeline:
    - decode JPEG and scale to [0, 1];
    - random horizontal and vertical flips;
    - random brightness shift (max_delta 0.04), then clip to [0, 1];
    - resize to 110% of the target size and random-crop to
      IMAGE_HEIGHT x IMAGE_WIDTH;
    - rescale to [-1, 1] to match the generator's tanh range.

    Args:
        caption: caption tensor (the sentence embedding); passed through unchanged.
        image_path: scalar string tensor holding the JPEG path.

    Returns:
        (img, caption), where img is float32 of shape
        (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL).
    """
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    # Work in [0, 1]: random_brightness adds a delta in [-max_delta, max_delta],
    # so on the raw [0, 255] scale a 0.04 shift was effectively a no-op.
    img = tf.cast(img, tf.float32) / 255.0
    # data augmentation
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)
    img = tf.image.random_brightness(img, max_delta=0.04)
    img = tf.clip_by_value(img, 0.0, 1.0)
    img = tf.image.resize(img, size=[IMAGE_HEIGHT + IMAGE_HEIGHT // 10, IMAGE_WIDTH + IMAGE_WIDTH // 10])
    img = tf.image.random_crop(img, size=[IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    img = img * 2 - 1

    return img, caption

def dataset_generator(filenames, batch_size, data_generator, num_rows=None, encode_batch_size=64):
    """Build the shuffled, batched training ``tf.data`` pipeline.

    Steps:
    - Every caption of every image becomes its own (caption, image_path) pair.
    - Caption id lists are decoded to text with ``sent2WordList``.
    - The texts are embedded once, up front, by the global ``text_encoder``,
      in chunks of ``encode_batch_size``. The encoder is therefore not trained.
    - The pair list is repeated NUM_CAP_PER_IMG times, mapped through
      ``data_generator``, shuffled and batched.

    Args:
        filenames: path of the pickled DataFrame with 'Captions' and 'ImagePath' columns.
        batch_size: batch size; the incomplete final batch is dropped.
        data_generator: per-example map fn taking (caption_embedding, image_path).
        num_rows: if given, only the first ``num_rows`` rows are used.
        encode_batch_size: number of captions sent to the text encoder per call.

    Returns:
        tf.data.Dataset of batched ``data_generator`` outputs.

    Raises:
        ValueError: if the selected rows contain no captions.
    """
    df = pd.read_pickle(filenames)
    if num_rows is not None:
        df = df.head(num_rows)

    # Flatten to one (caption text, image path) pair per caption. Captions are
    # decoded element by element, so rows with different caption lengths work
    # (np.asarray on ragged lists fails on recent NumPy).
    texts = []
    image_path = []
    for caps, path in zip(df['Captions'].values, df['ImagePath'].values):
        for cap in caps:
            texts.append(sent2WordList(cap))
            image_path.append(path)
    if not texts:
        raise ValueError('no captions found in %s' % filenames)

    # Embed in fixed-size chunks to bound the encoder's memory use.
    caption = tf.concat(
        [text_encoder(texts[start:start + encode_batch_size])
         for start in range(0, len(texts), encode_batch_size)],
        axis=0,
    ).numpy()

    caption = np.concatenate([caption] * NUM_CAP_PER_IMG, axis=0)
    image_path = np.concatenate([np.asarray(image_path)] * NUM_CAP_PER_IMG, axis=0)

    # each row of `caption` corresponds to the same row of `image_path`
    assert caption.shape[0] == image_path.shape[0]

    dataset = tf.data.Dataset.from_tensor_slices((caption, image_path))
    dataset = dataset.map(data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(len(caption)).batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset


In [None]:

# Seed Python's and TensorFlow's global RNGs to make runs more repeatable.
random.seed(0)
tf.random.set_seed(0)

BATCH_SIZE = 8
NUM_TRAIN_ROWS = 100  # rows of text2ImgData.pkl used for training; None = all rows
dataset = dataset_generator(
    data_path + '/text2ImgData.pkl',
    BATCH_SIZE,
    training_data_generator_aug,
    num_rows=NUM_TRAIN_ROWS,
)


In [None]:
import imageio

# Sample images are shown as a SAMPLE_ROW x SAMPLE_COL grid.
SAMPLE_ROW, SAMPLE_COL = 3, 4
SAMPLE_NUM = SAMPLE_ROW * SAMPLE_COL

def generate_img(imgs, row, col, path=None):
    """Tile images row-major into a ``row`` x ``col`` uint8 grid, optionally saving it.

    At most ``row * col`` images are used; extra images are ignored and unused
    cells stay black. Pixels are cast to uint8 by NumPy assignment, so callers
    should pass values already scaled to [0, 255].

    Args:
        imgs: non-empty sequence/array/tensor of same-shaped (h, w, c) images.
        row: number of grid rows.
        col: number of grid columns.
        path: if given, the grid is written there with ``imageio.imwrite``;
            missing parent directories are created first.

    Returns:
        np.ndarray of shape (h * row, w * col, c), dtype uint8.
    """
    h, w, c = imgs[0].shape
    out = np.zeros((h * row, w * col, c), dtype=np.uint8)
    # Images beyond the grid would hit an empty out-of-range slice and raise,
    # so only the first row * col are placed.
    for n, img in enumerate(imgs[:row * col]):
        j, i = divmod(n, col)
        out[j * h:(j + 1) * h, i * w:(i + 1) * w, :] = img
    if path is not None:
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        imageio.imwrite(path, out)
    return out

num_steps = len(dataset)
print(f'Num steps: {num_steps}')

# Peek at one augmented training batch. The captions are already sentence
# embeddings at this point, so only the images are displayed.
imgs, caps = next(iter(dataset))
imgs = imgs[:SAMPLE_NUM]
imgs = tf.clip_by_value((imgs + 1) / 2 * 255, 0, 255)  # [-1, 1] -> [0, 255]
img = generate_img(imgs, SAMPLE_ROW, SAMPLE_COL)

plt.imshow(img)
plt.axis("off")
plt.show()

# Optimizers and loss


In [None]:
# Separate Adam optimizers with identical settings (beta_1 = 0.5, as in DCGAN).
adam_kwargs = dict(learning_rate=LEARNING_RATE, beta_1=0.5)
optimizer_g = keras.optimizers.Adam(**adam_kwargs)
optimizer_d = keras.optimizers.Adam(**adam_kwargs)


In [None]:
EPOCHS_PER_CKPT = 5

checkpoint_path = './ckpt-bert3'

ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    optimizer_g=optimizer_g,
    optimizer_d=optimizer_d,
)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=3)

# Resume if a checkpoint exists. This assumes checkpoint number k (the suffix
# after the last '-') was written after epoch k * EPOCHS_PER_CKPT, so training
# continues from the following epoch.
start_epoch = 1
latest = ckpt_manager.latest_checkpoint
if latest:
    ckpt.restore(latest)
    latest_ckpt = int(latest.split('-')[-1])
    start_epoch = latest_ckpt * EPOCHS_PER_CKPT + 1
    print(f'Restore from latest checkpoint: {latest.split("/")[-1]}')

print(f'Start epoch: {start_epoch}')


In [None]:
LAMBDA = 10  # weight of the WGAN-GP gradient penalty


@tf.function
def train_step_g(real_img, caption):
    """One generator update: minimize -mean(D(text, G(text, z))).

    ``real_img`` is unused but kept so both step functions share a signature.
    The noise batch is sized from ``caption``. Returns the generator loss.
    """
    input_g = tf.random.normal([tf.shape(caption)[0], Z_DIM])

    with tf.GradientTape() as tape_g:
        fake_img = generator(caption, input_g, training=True)
        fake_pred = discriminator(caption, fake_img, training=True)
        loss_g = -tf.reduce_mean(fake_pred)

    gradient_g = tape_g.gradient(loss_g, generator.trainable_variables)
    optimizer_g.apply_gradients(zip(gradient_g, generator.trainable_variables))

    return loss_g


@tf.function
def train_step_d(real_img, caption):
    """One critic update with the WGAN-GP loss.

    loss_d = mean(D(fake)) - mean(D(real)) + LAMBDA * mean((||dD/dx_hat||_2 - 1)^2),
    where x_hat = eps * real + (1 - eps) * fake with one eps ~ U(0, 1) per example.
    Returns the critic loss.
    """
    batch = tf.shape(real_img)[0]
    input_g = tf.random.normal([batch, Z_DIM])
    epsilon = tf.random.uniform(shape=[batch, 1, 1, 1], minval=0, maxval=1)

    # Only the critic is updated here, so the generator pass stays off the tapes.
    fake_img = generator(caption, input_g, training=True)
    fake_img_gp = epsilon * real_img + (1 - epsilon) * fake_img

    with tf.GradientTape() as tape_d:
        with tf.GradientTape() as tape_gp:
            # x_hat is a plain tensor, not a variable: watch it explicitly.
            tape_gp.watch(fake_img_gp)
            fake_pred_gp = discriminator(caption, fake_img_gp, training=True)

        # Computed inside tape_d so the penalty is differentiable w.r.t. the critic.
        gradient_gp = tape_gp.gradient(fake_pred_gp, fake_img_gp)
        # Per-example L2 norm over the NHWC image axes; the small constant keeps
        # sqrt differentiable when a gradient is exactly zero (avoids NaN).
        gradient_norm_gp = tf.sqrt(tf.reduce_sum(tf.square(gradient_gp), axis=[1, 2, 3]) + 1e-12)
        gradient_penalty = tf.reduce_mean(tf.square(gradient_norm_gp - 1))

        fake_pred = discriminator(caption, fake_img, training=True)
        real_pred = discriminator(caption, real_img, training=True)

        loss_d = tf.reduce_mean(fake_pred) - tf.reduce_mean(real_pred) + LAMBDA * gradient_penalty

    gradient_d = tape_d.gradient(loss_d, discriminator.trainable_variables)
    optimizer_d.apply_gradients(zip(gradient_d, discriminator.trainable_variables))

    return loss_d


In [None]:
@tf.function
def test_step(caption):
    """Generate one image per caption embedding in ``caption`` (inference mode).

    The noise batch is sized from the input, so batch sizes other than
    BATCH_SIZE also work.
    """
    input_g = tf.random.normal([tf.shape(caption)[0], Z_DIM])
    return generator(caption, input_g, training=False)


# Visualization captions

# 

In [None]:
# 42 fixed captions for qualitative monitoring. The training loop embeds the
# first SAMPLE_NUM of them and, with fixed noise, renders a sample grid each epoch.
# Strings are kept verbatim, including original typos ('endges', 'cener', 'whth').
visual_cap = [
    'flower with white long white petals and very long purple stamen ',
    'this medium white flower has rows of thin blue petals and thick stamen ',
    'this flower is white and purple in color with petals that are oval shaped ',
    'this flower is pink and yellow in color with petals that are oval shaped ',
    'the flower has a large bright orange petal with pink anther ',
    'the flower shown has a smooth white petal with patches of yellow as well ',
    'white petals that become yellow as they go to the center where there is an orange stamen ',
    'this flower has bright red petals with green pedicel as its main features ',
    'this flower has the overlapping yellow petals arranged closely toward the center ',
    'this flower has green sepals surrounding several layers of slightly ruffled pink petals ',
    'the pedicel on this flower is purple with a green sepal and rose colored petals ',
    'this white flower has connected circular petals with yellow stamen ',
    'the flower has yellow petals overlapping each other and are yellow in color ',
    'this flower has numerous stamen ringed by multiple layers of thin pink petals ',
    'the petals are broad but thin at the edges with purple tints at the edges and white in the middle ',
    'the yellow flower has petals that are soft smooth and arranged in two layers below the bunch of stamen ',
    'this flower has petals that are pink and yellow with yellow stamen ',
    'red stacked petals surround yellow stamen and a black pistil ',
    'the petals of the flower are in multiple layers and are pink in yellow in color ',
    'this flower has a yellow center and layers of peach colored petals with pointed tips ',
    'this bright pink flower has several fluttery petals and a tubular center ',
    'this flower is white and yellow in color with petals that are rounded at the endges ',
    'the flower has a several pieces of yellow colored petals that looks similar to its leaves ',
    'this flower has several light pink petals and yellow anthers ',
    'this flower is yellow and white in color with petals that are star shaped near the cener ',
    'lavender and white pedal and yellow small flower in the middle of the pedals ',
    'this flower has lavender petals with maroon stripes and brown anther filaments ',
    'this flower has six plain pale yellow petals that alternate with three dark yellow speckled petals ',
    'this flower has petals that are yellow with orange lines ',
    'the flower has petals that are orange with yellow stamen ',
    'this flower has a brown center surrounded by layers of long yellow petals with rounded tips ',
    'this flower is lavender in color with petals that are ruffled and wavy ',
    'this flower is blue in color with petals that have veins ',
    'the petals on this flower are white with yellow stamen ',
    'this flower has petals that are cone shaped and dark purple ',
    'this flower is purple and white in color and has petals that are multi colored ',
    'a large group of bells that are blue on this flower ',
    'this flower is bright purple with many pedals that are roundish whth pale white outer petals ',
    'this flower has large yellow petals and long yellow stamen on it ',
    'this flower has spiky blue petals and a spiky black stigma on it ',
    'this flower is purple and yellow in color with petals that are oval shaped ',
    'the flower has petals that are large and pink with yellow anther',
]


# Testing Dataset

In [None]:
def testing_data_generator(caption, index):
    """Map fn for the test set: cast the caption embedding to float32.

    Note the output order is ``(index, caption)``, swapped relative to the inputs.
    """
    return index, tf.cast(caption, tf.float32)

def testing_dataset_generator(batch_size, data_generator, encode_batch_size=64,
                              data_file='./dataset/testData.pkl'):
    """Build the (infinite) test-set ``tf.data`` pipeline.

    Caption id lists from ``data_file`` are decoded with ``sent2WordList`` and
    embedded by the global ``text_encoder`` in chunks of ``encode_batch_size``.
    Each embedding is paired with its 'ID', mapped through ``data_generator``,
    repeated forever and batched. Callers must therefore stop after
    ceil(N / batch_size) batches.

    Args:
        batch_size: batch size; incomplete batches are dropped (repeat() refills them).
        data_generator: per-example map fn taking (caption_embedding, id).
        encode_batch_size: number of captions sent to the text encoder per call.
        data_file: pickled DataFrame with 'Captions' and 'ID' columns.

    Returns:
        infinite tf.data.Dataset of batched ``data_generator`` outputs.

    Raises:
        ValueError: if the file contains no captions.
    """
    data = pd.read_pickle(data_file)
    # Decode row by row; np.asarray on ragged caption lists fails on recent NumPy.
    texts = [sent2WordList(ids) for ids in data['Captions'].values]
    if not texts:
        raise ValueError('no captions found in %s' % data_file)

    caption = tf.concat(
        [text_encoder(texts[start:start + encode_batch_size])
         for start in range(0, len(texts), encode_batch_size)],
        axis=0,
    ).numpy()
    index = np.asarray(data['ID'].values)

    assert caption.shape[0] == index.shape[0]

    dataset = tf.data.Dataset.from_tensor_slices((caption, index))
    dataset = dataset.map(data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat().batch(batch_size, drop_remainder=True)

    return dataset


In [None]:

testing_dataset = testing_dataset_generator(BATCH_SIZE, data_generator=testing_data_generator)


In [None]:
# Size of the test set; the inference cell uses it to decide how many batches to run.
data = pd.read_pickle('./dataset/testData.pkl')
captions = data['Captions'].values

NUM_TEST = len(captions)
EPOCH_TEST = NUM_TEST // BATCH_SIZE  # number of full batches


# Inference

In [None]:
# Output folder for the generated test images.
demo_dir = './inference/demo'
if not os.path.exists(demo_dir):
    os.makedirs(demo_dir)


In [None]:
def inference(dataset, out_dir='./inference/demo'):
    """Generate and save one image per test caption.

    ``dataset`` yields ``(ids, caption_embeddings)`` batches and repeats
    forever. Exactly ceil(NUM_TEST / BATCH_SIZE) batches are taken, which is
    enough to cover every test ID. The former ``step > EPOCH_TEST`` check ran
    one extra, wrapped-around batch whenever NUM_TEST was a multiple of
    BATCH_SIZE. Images are mapped from [-1, 1] to [0, 1] and written as
    ``<out_dir>/inference_<id:04d>.jpg``.

    Args:
        dataset: the dataset built by ``testing_dataset_generator``.
        out_dir: existing directory to write the images into.
    """
    num_batches = -(-NUM_TEST // BATCH_SIZE)  # ceil division
    start = time.time()
    for idx, captions in dataset.take(num_batches):
        # One host transfer per batch instead of one per image.
        fake_images = test_step(captions).numpy()
        for image_id, fake in zip(idx.numpy(), fake_images):
            img = np.clip(fake * 0.5 + 0.5, 0.0, 1.0)
            plt.imsave('{}/inference_{:04d}.jpg'.format(out_dir, image_id), img)

    print('Time for inference is {:.4f} sec'.format(time.time() - start))

In [None]:
inference(dataset=testing_dataset)

# Training with periodic inference and scoring

In [None]:
import subprocess

def run_command_with_epoch(epoch):
    """Run ``inception_score.py`` on the generated test images for one epoch.

    Call this with the 'testing' directory as the working directory. The
    script, the image folder ("../inference/demo") and the CSV destination are
    all relative to it. The script runs under the same interpreter as this
    kernel (``sys.executable``), not whatever ``python`` is first on PATH, so
    it sees the same installed packages. Failures are printed, not raised
    (best effort).

    Args:
        epoch: epoch number embedded in the output file name.

    Returns:
        Path of the CSV the script was asked to write. The file may not exist
        if the command failed.
    """
    import sys

    output_file = f"../score_demo_epoch_{epoch:04d}.csv"
    # "39" is forwarded verbatim; its meaning is defined by inception_score.py.
    command = [sys.executable, "inception_score.py", "../inference/demo", output_file, "39"]

    try:
        subprocess.run(command, check=True)
        print(f"Command executed successfully, output saved to {output_file}.")
    except subprocess.CalledProcessError as e:
        print(f"Command failed with error: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")

    return output_file

def eval_score(file_path):
    """Print and return the mean of the 'score' column of a results CSV.

    If ``file_path`` does not exist (e.g. the scoring command failed), a message
    is printed and None is returned instead of raising, so a failed evaluation
    does not abort the training loop.

    Args:
        file_path: CSV with a 'score' column.

    Returns:
        the average score, or None if the file is missing.
    """
    if not os.path.exists(file_path):
        print(f'Score file not found: {file_path}')
        return None

    data = pd.read_csv(file_path)
    average_score = data['score'].mean()
    print(f'Average score: {average_score}')
    return average_score

In [None]:
# Training loop. Each epoch:
# - N_CRITIC critic updates per generator update;
# - a checkpoint every EPOCHS_PER_CKPT epochs;
# - a sample grid saved to imgs/;
# - every EPOCHS_PER_EVAL epochs, test-set inference plus Inception scoring.
os.makedirs('samples/demo', exist_ok=True)
# generate_img below writes the per-epoch grids to imgs/, so it must exist.
os.makedirs('imgs', exist_ok=True)

EPOCHS = 2000
N_CRITIC = 3          # critic updates per generator update
EPOCHS_PER_EVAL = 5   # how often to display samples, run inference and score

loss_x_list = []
loss_g_list = []
loss_d_list = []

start = time.time()

sample_z = tf.random.normal([SAMPLE_NUM, Z_DIM])
sample_cap = visual_cap[:SAMPLE_NUM]
# The text encoder is never trained, so the sample embeddings are computed once.
sample_text = text_encoder(sample_cap)

for epoch in range(start_epoch, EPOCHS + 1):

    loss_g = 0
    loss_d = 0

    # Train step. Each cycle of N_CRITIC + 1 steps is N_CRITIC critic updates
    # followed by one generator update. The old `== N_CRITIC - 1` test ran the
    # generator after only N_CRITIC - 1 critic steps in the first cycle.
    for step, (img, cap) in tqdm(enumerate(dataset), total=num_steps):
        if step % (N_CRITIC + 1) == N_CRITIC:
            loss_g += train_step_g(img, cap)
            loss_x_list.append(epoch + step / num_steps)
            loss_g_list.append(float(loss_g))
            loss_d_list.append(float(loss_d) / N_CRITIC)
            loss_g = 0
            loss_d = 0
        else:
            loss_d += train_step_d(img, cap)

    # Save checkpoint
    if epoch % EPOCHS_PER_CKPT == 0:
        ckpt_manager.save()

    # Generate the sample grid from fixed noise and captions.
    sample_imgs = generator(sample_text, sample_z, training=False)
    sample_imgs = tf.clip_by_value((sample_imgs + 1) / 2 * 255, 0, 255)
    sample_grid = generate_img(sample_imgs, SAMPLE_ROW, SAMPLE_COL, f'imgs/{epoch:>04d}.png')

    if epoch % EPOCHS_PER_EVAL == 0:
        plt.imshow(sample_grid)
        plt.axis("off")
        plt.show()
        inference(testing_dataset)

        # The scoring script uses paths relative to ./testing. Always restore
        # the working directory, even if scoring fails, so later relative paths
        # (checkpoints, imgs/, inference/) keep working.
        original_dir = os.getcwd()
        os.chdir("./testing")
        try:
            output_file = run_command_with_epoch(epoch)
            eval_score(output_file)
        finally:
            os.chdir(original_dir)

    # Nothing is logged for an epoch too short to reach a generator step.
    if loss_g_list:
        print(f'Epoch {epoch}: Loss_g {loss_g_list[-1]:.5f}, Loss_d {loss_d_list[-1]:.5f}')

print(f'Time taken for {EPOCHS} epoch: {time.time() - start} sec')