In [None]:
import os

import numpy as np
import pandas as pd

# Preprocess Text

In [None]:
# Directory that holds the vocabulary file.
dictionary_path = './dictionary'


# Mapping used by caption2str, which looks words up by the string form of their id.
# NOTE(review): dict() over the loaded array presumably relies on it being a
# sequence of (id, word) pairs - confirm against the file contents.
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))

def caption2str(caption):
    """Convert a sequence of word ids into a space-separated sentence.

    Each id is looked up in ``id2word_dict`` by its string form. Ids that are
    missing from the dictionary, words starting with '<' (special tokens) and
    empty words are skipped.

    Args:
        caption: iterable of word ids.

    Returns:
        The remaining words joined by single spaces ('' when none remain).
    """
    words = []
    for word_id in caption:
        word = id2word_dict.get(str(word_id))
        # `not word` covers both a missing id (None) and an empty vocabulary
        # entry; indexing word[0] on an empty entry would raise IndexError.
        if not word or word.startswith('<'):
            continue
        words.append(word)
    return ' '.join(words)


# Dataset

In [None]:
# Directory that holds the pickled training and test tables.
data_path = './dataset'
# Training table; dataset_generator later re-reads this same file for its
# 'Captions' and 'ImagePath' columns.
# NOTE(review): unpickling can execute arbitrary code - only load trusted files.
df = pd.read_pickle(data_path + '/text2ImgData.pkl')
# Row count of the training table, kept under two names.
num_training_sample = len(df)
n_images_train = num_training_sample
print('There are %d image in training data' % (n_images_train))


# Create Dataset by Dataset API


In [None]:
# Side length (pixels) of the square crop produced by preprocess_data.
IMAGE_SIZE = 64
# Batch size of both tf.data pipelines and of the noise drawn in the train/test steps.
BATCH_SIZE = 32
# Shuffle buffer size of the training pipeline.
BUFFER_SIZE = 5000

def preprocess_data(image_path, caption):
    """Load, augment and normalise one training image; cast its caption embedding.

    Args:
        image_path: scalar string tensor with the path of a JPEG file.
        caption: caption embedding for that image (cast to float32 here).

    Returns:
        (img, caption): img is an (IMAGE_SIZE, IMAGE_SIZE, 3) float32 tensor
        scaled to [-1, 1]; caption is the float32 embedding.
    """
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)

    # Scale the uint8 pixels to [0, 1] *before* augmenting: random_brightness
    # adds its delta directly to float images, so max_delta=0.04 is only
    # meaningful on the [0, 1] range (on [0, 255] it is effectively a no-op).
    img = tf.image.convert_image_dtype(img, tf.float32)

    # data augmentation
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)
    img = tf.image.random_brightness(img, max_delta=0.04)
    # Resize to 10% larger than the target, then take a random target-size crop.
    img = tf.image.resize(img, (IMAGE_SIZE + IMAGE_SIZE // 10, IMAGE_SIZE + IMAGE_SIZE // 10))
    img = tf.image.random_crop(img, (IMAGE_SIZE, IMAGE_SIZE, 3))

    # The brightness shift can push pixels slightly outside [0, 1]; clip so the
    # result stays inside the [-1, 1] range of the generator's tanh output.
    img = tf.clip_by_value(img, 0.0, 1.0)
    img = img * 2 - 1
    caption = tf.cast(caption, tf.float32)
    return img, caption

def dataset_generator(pkl_file):
    """Build the shuffled, batched training pipeline from a pickled table.

    The table's 'Captions' column holds, per row, a collection of captions
    (each a sequence of word ids) and 'ImagePath' holds the matching image
    path. Every caption becomes one (image path, embedding) sample, and the
    whole sample table is then repeated NUM_CAP_PER_IMG times.

    Relies on the notebook globals caption2str, text_encoder, NUM_CAP_PER_IMG,
    BUFFER_SIZE and BATCH_SIZE.

    Args:
        pkl_file: path of the pickled DataFrame.

    Returns:
        A tf.data.Dataset of (image, caption embedding) batches.
    """
    frame = pd.read_pickle(pkl_file)
    captions = frame['Captions'].values
    image_paths = frame['ImagePath'].values

    # Flatten to one entry per caption, keeping the image path aligned with it.
    caption = [cap for caps in captions for cap in caps]
    image_path = [image_paths[row] for row, caps in enumerate(captions) for _ in caps]

    # Word ids -> sentences.
    id_matrix = np.asarray(caption).astype(np.int32)
    sentences = [caption2str(ids) for ids in id_matrix]

    # Encode 64 sentences at a time and stitch the embeddings back together.
    chunk = 64
    embedded = [
        text_encoder(sentences[lo:lo + chunk])
        for lo in range(0, len(sentences), chunk)
    ]
    caption = tf.concat(embedded, axis=0).numpy()

    # Repeat the full sample table NUM_CAP_PER_IMG times.
    caption = np.concatenate([caption] * NUM_CAP_PER_IMG, axis=0)
    image_path = np.concatenate([image_path] * NUM_CAP_PER_IMG, axis=0)

    assert caption.shape[0] == image_path.shape[0]

    autotune = tf.data.experimental.AUTOTUNE
    dataset = tf.data.Dataset.from_tensor_slices((image_path, caption))
    dataset = dataset.map(preprocess_data, num_parallel_calls=autotune)
    dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=autotune)

    return dataset


In [None]:
# NOTE(review): dataset_generator reads the globals `text_encoder` and
# `NUM_CAP_PER_IMG`, which this notebook only defines in later cells; on a
# fresh kernel those cells have to be run before this one.
dataset = dataset_generator(data_path + '/text2ImgData.pkl')
# Batches per epoch; the training loop uses it as the progress-bar total and to
# place loss samples on a fractional-epoch axis.
num_steps = len(dataset)
print(f'Num steps: {num_steps}')


# GAN

# Text Encoder

In [None]:
class TFSentenceTransformer(keras.layers.Layer):
    """Keras layer that turns tokenized text into one embedding per sentence.

    Wraps a pretrained transformer loaded with TFAutoModel and pools its
    per-token output with an attention-mask-weighted mean.
    """

    def __init__(self, model_name):
        """Load the pretrained transformer identified by ``model_name``."""
        super().__init__()
        self.model = TFAutoModel.from_pretrained(model_name)

    def call(self, inputs, normalize=True):
        """Embed a tokenized batch.

        Args:
            inputs: tokenizer output; must provide an 'attention_mask' entry.
            normalize: when true, L2-normalise each embedding.

        Returns:
            One pooled embedding per input sequence.
        """
        pooled = self.mean_pooling(self.model(inputs), inputs['attention_mask'])
        return self.normalize(pooled) if normalize else pooled

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings over the positions the mask keeps."""
        # The first element of the model output is used as the per-token embeddings.
        token_embeddings = model_output[0]

        # Expand the mask to the embeddings' shape so masked positions contribute 0.
        mask = tf.expand_dims(attention_mask, -1)
        mask = tf.broadcast_to(mask, tf.shape(token_embeddings))
        mask = tf.cast(mask, tf.float32)

        summed = tf.math.reduce_sum(token_embeddings * mask, axis=1)
        counts = tf.math.reduce_sum(mask, axis=1)
        # Lower-bound the divisor so a fully masked row cannot divide by zero.
        return summed / tf.clip_by_value(counts, 1e-9, tf.float32.max)

    def normalize(self, embeddings):
        """Return ``embeddings`` scaled to unit L2 norm along axis 1."""
        return tf.linalg.normalize(embeddings, 2, axis=1)[0]


class E2ESentenceTransformer(keras.Model):
    """End-to-end sentence encoder: raw strings in, sentence embeddings out."""

    def __init__(self, model_name):
        """Load the tokenizer and the pooled transformer for ``model_name``."""
        super().__init__(name='text_encoder')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFSentenceTransformer(model_name)

    def call(self, inputs):
        """Tokenize ``inputs`` (padded and truncated, as TF tensors) and embed them."""
        encoded = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            return_tensors='tf',
        )
        return self.model(encoded)
    

# # model_name = 'sentence-transformers/all-MiniLM-L6-v2'
# model_name = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
# text_encoder = E2ESentenceTransformer(model_name)


# Generator

In [None]:
class Generator(keras.Model):
    """Text-conditioned image generator.

    The caption embedding is projected to 384 features, concatenated with the
    noise vector, mapped to a 4x4x512 tensor and upsampled by transposed
    convolutions (one stride-1 stage, then four stride-2 stages) to a
    3-channel image with tanh activation.
    """

    def __init__(self, input_z_shape, emb_shape):
        """Create the text projection and the upsampling stack.

        Args:
            input_z_shape: length of the noise vector (used by summary()).
            emb_shape: length of the caption embedding (used by summary()).
        """
        super().__init__(name='generator')
        self.input_z_shape = input_z_shape
        self.emb_shape = emb_shape

        def deconv(filters, strides, activation=None):
            # 4x4 transposed convolution shared by every upsampling stage.
            return keras.layers.Conv2DTranspose(
                filters,
                (4, 4),
                strides=strides,
                padding='same',
                use_bias=False,
                activation=activation,
                kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
            )

        # Caption embedding -> 384-d conditioning vector.
        self.text = keras.Sequential([
            keras.layers.Flatten(),
            keras.layers.Dense(384),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
        ])

        # [noise, conditioning] -> 4x4x512 seed tensor.
        stack = [
            keras.layers.Dense(8192, use_bias=False),
            keras.layers.Reshape((4, 4, 512)),
        ]
        # Hidden stages: transposed conv + batch norm + LeakyReLU.
        for filters, strides in ((512, (1, 1)), (256, (2, 2)), (128, (2, 2)), (64, (2, 2))):
            stack.append(deconv(filters, strides))
            stack.append(keras.layers.BatchNormalization())
            stack.append(keras.layers.LeakyReLU())
        # Output stage: 3 channels squashed by tanh.
        stack.append(deconv(3, (2, 2), activation='tanh'))
        self.generator = keras.Sequential(stack)

    def call(self, text, noise_z):
        """Generate images from caption embeddings ``text`` and noise ``noise_z``."""
        conditioning = self.text(text)
        return self.generator(tf.concat([noise_z, conditioning], axis=1))

    def summary(self):
        """Summarise the network via a functional model built on symbolic inputs."""
        text_in = keras.layers.Input(shape=(self.emb_shape, ), name='text')
        noise_in = keras.layers.Input(shape=(self.input_z_shape, ), name='noise_z')
        wrapper = keras.Model(
            name='generator',
            inputs=[text_in, noise_in],
            outputs=self.call(text_in, noise_in),
        )
        return wrapper.summary()

# generator = Generator(Z_DIM, EMB_DIM)
# generator.summary()


# Discriminator

In [None]:
class Discriminator(keras.Model):
    """Text-conditioned critic that scores (caption embedding, image) pairs.

    The image goes through four stride-2 convolutions; the caption embedding is
    projected to 384 features and reshaped to a 4x4 map that is concatenated
    with the image features along the channel axis. A small convolutional head
    followed by Dense(1) yields one unbounded score per sample.

    NOTE(review): the text branch uses BatchNormalization while the training
    step applies a per-sample gradient penalty - confirm this is intended.
    """

    def __init__(self, input_x_shape, emb_shape):
        """Create the text branch, the image branch and the scoring head.

        Args:
            input_x_shape: image shape without the batch axis (used by summary()).
            emb_shape: length of the caption embedding (used by summary()).
        """
        # The model is named 'discriminator'; it previously carried the
        # copy-pasted name 'generator', clashing with the Generator model.
        super().__init__(name='discriminator')
        self.input_x_shape = input_x_shape
        self.emb_shape = emb_shape

        def conv(filters, kernel_size, strides):
            # Bias-free convolution with the initializer shared by every conv layer.
            return keras.layers.Conv2D(
                filters,
                kernel_size,
                strides=strides,
                padding='same',
                use_bias=False,
                kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
            )

        # Caption embedding -> 384-d conditioning vector.
        self.text = keras.Sequential([
            keras.layers.Flatten(),
            keras.layers.Dense(384),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
        ])

        # Image -> feature map; each stage halves the spatial size.
        image_layers = []
        for filters in (64, 128, 256, 512):
            image_layers.append(conv(filters, (4, 4), (2, 2)))
            image_layers.append(keras.layers.LeakyReLU())
        self.image = keras.Sequential(image_layers)

        # Joint (image, text) features -> single score.
        self.discriminator = keras.Sequential([
            conv(512, (1, 1), (1, 1)),
            keras.layers.LeakyReLU(),
            conv(1, (4, 4), (1, 1)),
            keras.layers.Flatten(),
            keras.layers.Dense(1),
        ])

    def call(self, text, img):
        """Score each (caption embedding, image) pair.

        Args:
            text: batch of caption embeddings.
            img: batch of images; the image branch must reduce them to a 4x4
                map so they can be concatenated with the reshaped text features.

        Returns:
            Tensor with one score per sample.
        """
        text = self.text(text)
        # Spread the text features over a 4x4 grid (16 cells) to match the image map.
        text = tf.reshape(text, shape=(-1, 4, 4, text.shape[-1] // 16))
        img = self.image(img)
        text_concat = tf.concat([img, text], axis=3)
        output = self.discriminator(text_concat)
        return output

    def summary(self):
        """Summarise the network via a functional model built on symbolic inputs."""
        text = keras.layers.Input(shape=(self.emb_shape, ), name='text')
        img = keras.layers.Input(shape=self.input_x_shape, name='img')
        model = keras.Model(name='discriminator', inputs=[text, img], outputs=self.call(text, img))
        return model.summary()

# discriminator = Discriminator(IMAGE_SHAPE, EMB_DIM)
# discriminator.summary()


# Parameters
wrap in h{}?

In [None]:
# Length of the noise vector drawn for the generator.
Z_DIM = 128
# Length of a caption embedding as passed to Generator / Discriminator.
# NOTE(review): presumably must equal the text encoder's output width - confirm.
EMB_DIM = 768
# Image shape (height, width, channels) handed to the Discriminator.
IMAGE_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)
# Number of times dataset_generator repeats the full (image path, caption) table.
NUM_CAP_PER_IMG = 3

# Learning rate of both Adam optimizers.
LEARNING_RATE = 1e-4
# Weight of the gradient penalty in the critic loss.
LAMBDA = 10
# Last epoch index of the training loop (epochs are counted from 1).
EPOCHS = 500
# A checkpoint is saved every this many epochs.
EPOCHS_PER_CKPT = 5
# Controls how critic and generator steps alternate in the training loop.
N_CRITIC = 3

# Grid layout of the sample image written after every epoch.
SAMPLE_ROW = 3
SAMPLE_COL = 4
SAMPLE_NUM = SAMPLE_ROW * SAMPLE_COL
# NOTE(review): not used in the cells shown here; presumably the GIF duration
# passed to generate_gif - confirm.
SAMPLE_DURATION = 10

In [None]:
# Pretrained sentence-embedding model loaded by the text encoder.
model_name = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
text_encoder = E2ESentenceTransformer(model_name)

# Instantiate both networks and print their layer summaries.
generator = Generator(Z_DIM, EMB_DIM)
generator.summary()

discriminator = Discriminator(IMAGE_SHAPE, EMB_DIM)
discriminator.summary()


# Loss & Optimizer

In [None]:
# Separate Adam optimizers (same settings) for the generator and the discriminator.
optimizer_g = keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=0.5)
optimizer_d = keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=0.5)


In [None]:
checkpoint_path = './ckpt-bert3'

# Both networks and both optimizers are tracked by a single checkpoint object.
ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    optimizer_g=optimizer_g,
    optimizer_d=optimizer_d,
)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=3)

# Training starts at epoch 1 unless a checkpoint is found.
start_epoch = 1
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    # The number after the last '-' of the checkpoint path is taken as the save
    # count; one save happens every EPOCHS_PER_CKPT epochs, so resume right after it.
    latest_ckpt = int(ckpt_manager.latest_checkpoint.rsplit('-', 1)[-1])
    start_epoch = latest_ckpt * EPOCHS_PER_CKPT + 1
    print(f'Restore from latest checkpoint: {ckpt_manager.latest_checkpoint.split("/")[-1]}')

print(f'Start epoch: {start_epoch}')


In [None]:
# Loss history filled by the training loop, one entry per generator step:
# fractional epoch position, generator loss, and averaged discriminator loss.
loss_x_list = []
loss_g_list = []
loss_d_list = []


In [None]:
@tf.function
def train_step_g(real_img, caption):
    """Run one generator update and return the generator loss.

    Args:
        real_img: batch of real images; unused here, kept so both train steps
            share the same signature.
        caption: batch of caption embeddings used as conditioning.

    Returns:
        Scalar generator loss: the negated mean critic score of the fakes.
    """
    # Draw one noise vector per caption actually present in the batch instead
    # of assuming the batch always has BATCH_SIZE rows.
    batch_size = tf.shape(caption)[0]
    input_g = tf.random.normal([batch_size, Z_DIM])
    text = caption

    with tf.GradientTape() as tape_g:
        fake_img = generator(text, input_g, training=True)
        fake_pred = discriminator(text, fake_img, training=True)
        loss_g = -tf.reduce_mean(fake_pred)

    # Only the generator's variables are updated in this step.
    gradient_g = tape_g.gradient(loss_g, generator.trainable_variables)
    optimizer_g.apply_gradients(zip(gradient_g, generator.trainable_variables))

    return loss_g
    
@tf.function
def train_step_d(real_img, caption):
    """Run one critic update with gradient penalty and return the critic loss.

    Args:
        real_img: batch of real images.
        caption: batch of caption embeddings paired with ``real_img``.

    Returns:
        Scalar loss: mean(fake score) - mean(real score) + LAMBDA * gradient penalty.
    """
    # Size the noise and the interpolation weights from the incoming batch
    # instead of assuming it always has BATCH_SIZE rows.
    batch_size = tf.shape(real_img)[0]
    input_g = tf.random.normal([batch_size, Z_DIM])
    text = caption
    epsilon = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=1)

    with tf.GradientTape() as tape_d:
        with tf.GradientTape() as tape_gp:
            fake_img = generator(text, input_g, training=True)
            # Random per-sample interpolation between real and generated images.
            fake_img_gp = epsilon * real_img + (1 - epsilon) * fake_img
            fake_pred_gp = discriminator(text, fake_img_gp, training=True)

        # Taken inside tape_d so the penalty itself can be differentiated with
        # respect to the discriminator's variables.
        gradient_gp = tape_gp.gradient(fake_pred_gp, fake_img_gp)
        # Per-sample L2 norm over every non-batch axis; the small constant keeps
        # the gradient of sqrt finite when a sample's gradient is exactly zero.
        reduce_axes = list(range(1, len(gradient_gp.shape)))
        gradient_norm_gp = tf.sqrt(tf.reduce_sum(tf.square(gradient_gp), axis=reduce_axes) + 1e-12)
        gradient_penalty = tf.reduce_mean(tf.square(gradient_norm_gp - 1))

        fake_pred = discriminator(text, fake_img, training=True)
        real_pred = discriminator(text, real_img, training=True)

        loss_d = tf.reduce_mean(fake_pred) - tf.reduce_mean(real_pred) + LAMBDA * gradient_penalty

    # Only the discriminator's variables are updated in this step.
    gradient_d = tape_d.gradient(loss_d, discriminator.trainable_variables)
    optimizer_d.apply_gradients(zip(gradient_d, discriminator.trainable_variables))

    return loss_d


In [None]:
@tf.function
def test_step(caption):
    """Generate one image per caption embedding with the generator in inference mode.

    Args:
        caption: batch of caption embeddings.

    Returns:
        Batch of generated images (the generator's tanh output).
    """
    # One noise vector per caption in the batch, whatever the batch size is.
    batch_size = tf.shape(caption)[0]
    input_g = tf.random.normal([batch_size, Z_DIM])
    text = caption
    fake_img = generator(text, input_g, training=False)
    return fake_img


# Visualization

In [None]:
def generate_img(imgs, row, col, path=None):
    """Tile equally sized images into a ``row`` x ``col`` grid.

    Images are placed row by row, left to right; grid cells without an image
    stay black. Values are stored as uint8.

    Args:
        imgs: sequence of images shaped (h, w, c).
        row: number of grid rows.
        col: number of grid columns.
        path: when given, the grid is also written there with imageio.

    Returns:
        The (h * row, w * col, c) uint8 grid.
    """
    tile_h, tile_w, channels = imgs[0].shape
    canvas = np.zeros((tile_h * row, tile_w * col, channels), dtype=np.uint8)

    for position, tile in enumerate(imgs):
        grid_row = position // col
        grid_col = position % col
        top = grid_row * tile_h
        left = grid_col * tile_w
        canvas[top : top + tile_h, left : left + tile_w, :] = tile

    if path is not None:
        imageio.imwrite(path, canvas)
    return canvas
  
def generate_gif(imgs_path_list, fname, duration):
    """Assemble image files into a GIF, each frame labelled with its file stem.

    A 20-pixel white banner is put above every image and 'EPOCH: <stem>' is
    drawn on it in black, where <stem> is the file name up to its first dot.

    Args:
        imgs_path_list: paths of the frame images, in playback order.
        fname: output GIF path.
        duration: length of the clip; the frame rate is len(frames) / duration.
    """
    frames = []
    for frame_path in imgs_path_list:
        picture = imageio.imread(frame_path)
        label = frame_path.split('/')[-1].split('.')[0]

        # White banner stacked on top of the picture (3-channel images expected).
        banner = np.ones((20, picture.shape[1], 3), dtype=np.uint8) * 255
        frame = imageio.core.util.Array(np.concatenate([banner, picture], axis=0))
        # Draw the label on the banner rows and store the result back.
        frame[:20, :, :] = cv2.putText(frame[:20, :, :], f'EPOCH: {label}', (10, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

        frames.append(frame)

    # Frames per unit of clip time; also maps a clip time t to a frame index.
    rate = float(len(frames)) / duration
    clip = mpy.VideoClip(lambda t: frames[int(rate * t)], duration=duration)
    clip.write_gif(fname, fps=rate)


In [None]:
# Fixed caption strings; the training loop encodes the first SAMPLE_NUM of them
# to render the per-epoch sample grid. The strings are fed to the text encoder
# as written.
test_cap = [
    'flower with white long white petals and very long purple stamen ',
    'this medium white flower has rows of thin blue petals and thick stamen ',
    'this flower is white and purple in color with petals that are oval shaped ',
    'this flower is pink and yellow in color with petals that are oval shaped ',
    'the flower has a large bright orange petal with pink anther ',
    'the flower shown has a smooth white petal with patches of yellow as well ',
    'white petals that become yellow as they go to the center where there is an orange stamen ',
    'this flower has bright red petals with green pedicel as its main features ',
    'this flower has the overlapping yellow petals arranged closely toward the center ',
    'this flower has green sepals surrounding several layers of slightly ruffled pink petals ',
    'the pedicel on this flower is purple with a green sepal and rose colored petals ',
    'this white flower has connected circular petals with yellow stamen ',
    'the flower has yellow petals overlapping each other and are yellow in color ',
    'this flower has numerous stamen ringed by multiple layers of thin pink petals ',
    'the petals are broad but thin at the edges with purple tints at the edges and white in the middle ',
    'the yellow flower has petals that are soft smooth and arranged in two layers below the bunch of stamen ',
    'this flower has petals that are pink and yellow with yellow stamen ',
    'red stacked petals surround yellow stamen and a black pistil ',
    'the petals of the flower are in multiple layers and are pink in yellow in color ',
    'this flower has a yellow center and layers of peach colored petals with pointed tips ',
    'this bright pink flower has several fluttery petals and a tubular center ',
    'this flower is white and yellow in color with petals that are rounded at the endges ',
    'the flower has a several pieces of yellow colored petals that looks similar to its leaves ',
    'this flower has several light pink petals and yellow anthers ',
    'this flower is yellow and white in color with petals that are star shaped near the cener ',
    'lavender and white pedal and yellow small flower in the middle of the pedals ',
    'this flower has lavender petals with maroon stripes and brown anther filaments ',
    'this flower has six plain pale yellow petals that alternate with three dark yellow speckled petals ',
    'this flower has petals that are yellow with orange lines ',
    'the flower has petals that are orange with yellow stamen ',
    'this flower has a brown center surrounded by layers of long yellow petals with rounded tips ',
    'this flower is lavender in color with petals that are ruffled and wavy ',
    'this flower is blue in color with petals that have veins ',
    'the petals on this flower are white with yellow stamen ',
    'this flower has petals that are cone shaped and dark purple ',
    'this flower is purple and white in color and has petals that are multi colored ',
    'a large group of bells that are blue on this flower ',
    'this flower is bright purple with many pedals that are roundish whth pale white outer petals ',
    'this flower has large yellow petals and long yellow stamen on it ',
    'this flower has spiky blue petals and a spiky black stigma on it ',
    'this flower is purple and yellow in color with petals that are oval shaped ',
    'the flower has petals that are large and pink with yellow anther',
]


In [None]:
# sample gen???

# Training

In [None]:
start = time.time()

# Fixed noise and captions so the per-epoch sample grids are comparable.
sample_z = tf.random.normal([SAMPLE_NUM, Z_DIM])
sample_cap = test_cap[:SAMPLE_NUM]
# The text encoder is not updated by the train steps, so the sample captions
# are embedded once here instead of once per epoch.
sample_text = text_encoder(sample_cap)

# Make sure the sample-image directory exists before the first image is written.
os.makedirs('imgs', exist_ok=True)

for epoch in range(start_epoch, EPOCHS + 1):

    loss_g = 0
    loss_d = 0
    d_steps = 0  # critic steps accumulated in loss_d since the last generator step

    # Train step: within every group of N_CRITIC + 1 batches, one batch drives
    # the generator and the others drive the discriminator.
    for step, (img, cap) in tqdm(enumerate(dataset), total=num_steps):
        if step % (N_CRITIC + 1) == N_CRITIC - 1:
            loss_g += train_step_g(img, cap)
            # Store data to list (as plain floats rather than tensors).
            loss_x_list.append(epoch + step / num_steps)
            loss_g_list.append(float(loss_g))
            # Average over the critic steps that actually ran: the first group
            # of an epoch contains fewer than N_CRITIC of them.
            loss_d_list.append(float(loss_d) / max(d_steps, 1))
            loss_g = 0
            loss_d = 0
            d_steps = 0
        else:
            loss_d += train_step_d(img, cap)
            d_steps += 1

    # Save checkpoint
    if epoch % EPOCHS_PER_CKPT == 0:
        ckpt_manager.save()

    # Generate sample image: map the tanh output from [-1, 1] to [0, 255].
    sample_imgs = generator(sample_text, sample_z, training=False)
    sample_imgs = tf.clip_by_value((sample_imgs + 1) / 2 * 255, 0, 255)
    img = generate_img(sample_imgs, SAMPLE_ROW, SAMPLE_COL, f'imgs/{epoch:>04d}.png')

    # Display sample image
    if epoch % 5 == 0:
        plt.imshow(img)
        plt.axis("off")
        plt.show()

    print(f'Epoch {epoch}: Loss_g {loss_g_list[-1]:.5f}, Loss_d {loss_d_list[-1]:.5f}')

print (f'Time taken for {EPOCHS} epoch: {time.time() - start} sec')


# Evaluation

# Testing Dataset

In [None]:
def test_preprocess_data(index, caption):
    """Pass the sample id through unchanged and cast its caption embedding to float32."""
    return index, tf.cast(caption, tf.float32)

def test_dataset_generator(pkl_file):
    """Build the inference pipeline of (id, caption embedding) batches.

    Reads the 'Captions' and 'ID' columns of the pickled table, turns each
    caption into a sentence with caption2str and embeds the sentences with the
    global text_encoder. The dataset repeats indefinitely, so every batch is
    full and the consumer decides when to stop.

    Args:
        pkl_file: path of the pickled DataFrame.

    Returns:
        A tf.data.Dataset yielding (ids, embeddings) batches of BATCH_SIZE.
    """
    frame = pd.read_pickle(pkl_file)

    # Word ids -> sentences (no integer cast is applied to the ids here).
    sentences = [caption2str(ids) for ids in np.asarray(frame['Captions'].values)]

    # Encode 64 sentences at a time and stitch the embeddings back together.
    chunk = 64
    embedded = [
        text_encoder(sentences[lo:lo + chunk])
        for lo in range(0, len(sentences), chunk)
    ]
    caption = tf.concat(embedded, axis=0).numpy()
    index = np.asarray(frame['ID'].values)

    assert caption.shape[0] == index.shape[0]

    dataset = tf.data.Dataset.from_tensor_slices((index, caption))
    dataset = dataset.map(test_preprocess_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat().batch(BATCH_SIZE, drop_remainder=True)

    return dataset


In [None]:

# Endlessly repeating (id, caption embedding) batches built from the test table.
test_dataset = test_dataset_generator(data_path + '/testData.pkl')


In [None]:

# Re-read the test table only to count its captions.
data = pd.read_pickle(data_path + '/testData.pkl')
captions = data['Captions'].values
NUM_TEST = len(captions)
# Number of full batches in the test set; inference() keeps consuming batches
# until its step counter exceeds this value.
EPOCH_TEST = int(NUM_TEST / BATCH_SIZE)


# Inference

In [None]:
# Directory that receives the generated test images.
output_dir = './inference/demo'

def inference(dataset):
    """Generate one image per test caption and save it under ``output_dir``.

    Consumes batches from ``dataset`` until the step counter exceeds
    EPOCH_TEST, i.e. EPOCH_TEST + 1 batches. Each image is written as
    'inference_<id>.jpg', the id being zero-padded to four digits.

    Args:
        dataset: iterable of (ids, caption embeddings) batches of BATCH_SIZE.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    start = time.time()
    for step, (idx, captions) in enumerate(dataset):
        if step > EPOCH_TEST:
            break

        # Map the generator output from [-1, 1] to [0, 1] for saving.
        batch = np.clip(test_step(captions).numpy() * 0.5 + 0.5, 0.0, 1.0)
        for i in range(BATCH_SIZE):
            plt.imsave(output_dir + '/inference_{:04d}.jpg'.format(idx[i]), batch[i])

    print('Time for inference is {:.4f} sec'.format(time.time() - start))
