In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
from keras import layers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path

import re
from IPython import display
import cv2

In [2]:
# Expose only the first physical GPU to TensorFlow and let its memory grow on
# demand instead of reserving the whole card up front. TensorFlow only accepts
# these settings before the GPU runtime is initialized; otherwise it raises
# RuntimeError, which is reported here rather than propagated.
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        first_gpu = gpus[0]
        tf.config.experimental.set_visible_devices(first_gpu, 'GPU')

        # Memory growth has to be configured identically on every physical GPU.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        print(err)

1 Physical GPUs, 1 Logical GPUs


## Preprocess text

In [3]:
# Vocabulary files shipped with the competition data.
dictionary_path = './2024-datalab-cup3-reverse-image-caption/dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

# Each mapping file holds (key, value) pairs, so dict() turns it into a lookup table.
word2Id_dict, id2word_dict = (
    dict(np.load(dictionary_path + mapping_file))
    for mapping_file in ('/word2Id.npy', '/id2Word.npy')
)
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
special_token_ids = (word2Id_dict['<PAD>'], word2Id_dict['<RARE>'])
print('Tokens: <PAD>: %s; <RARE>: %s' % special_token_ids)

there are 5427 vocabularies in total
Word to id mapping, for example: flower -> 1
Id to word mapping, for example: 1 -> flower
Tokens: <PAD>: 5427; <RARE>: 5428


In [4]:
def sent2IdList(line, MAX_SEQ_LENGTH=20, word2id=None):
    """Convert a caption into a fixed-length list of vocabulary ids.

    Punctuation is replaced by spaces and the text is split on whitespace. The
    token list is then truncated or padded with ``<PAD>`` so it has exactly
    ``MAX_SEQ_LENGTH`` entries. Tokens missing from the vocabulary map to the
    ``<RARE>`` id.

    Args:
        line: raw caption text.
        MAX_SEQ_LENGTH: length of the returned list.
        word2id: word -> id mapping. Defaults to the notebook-level
            ``word2Id_dict``. It must contain ``'<PAD>'`` and, whenever an
            unknown word occurs, ``'<RARE>'``.

    Returns:
        A list of exactly ``MAX_SEQ_LENGTH`` ids, in whatever type the mapping
        stores (strings for ``word2Id_dict``).
    """
    if word2id is None:
        word2id = word2Id_dict

    # string.punctuation already contains '-' and '.', so a single substitution
    # covers everything the earlier chain of replace() calls handled.
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
    # split() with no argument drops the empty strings left by repeated whitespace.
    tokens = prep_line.split()

    # Enforce the fixed length in both directions: long captions are truncated
    # (previously they came back longer than MAX_SEQ_LENGTH) and short ones padded.
    tokens = tokens[:MAX_SEQ_LENGTH]
    tokens += ['<PAD>'] * (MAX_SEQ_LENGTH - len(tokens))

    return [
        word2id[token] if token in word2id else word2id['<RARE>']
        for token in tokens
    ]

# Sanity check: tokenize one caption and show its padded id sequence.
text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
demo_ids = sent2IdList(text)
print(demo_ids)


the flower shown has yellow anther red pistil and bright red petals.
['9', '1', '82', '5', '11', '70', '20', '31', '3', '29', '20', '2', '5427', '5427', '5427', '5427', '5427', '5427', '5427', '5427']


## Dataset

In [6]:
data_path = './2024-datalab-cup3-reverse-image-caption/dataset'
# NOTE: unpickling can execute arbitrary code -- only open trusted files.
df = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = n_images_train = len(df)
print('There are %d image in training data' % (n_images_train))

There are 7370 image in training data


In [7]:
# First rows: tokenized captions and the image path for each ID.
df.head()

Unnamed: 0_level_0,Captions,ImagePath
ID,Unnamed: 1_level_1,Unnamed: 2_level_1
6734,"[[9, 2, 17, 9, 1, 6, 14, 13, 18, 3, 41, 8, 11,...",./102flowers/image_06734.jpg
6736,"[[4, 1, 5, 12, 2, 3, 11, 31, 28, 68, 106, 132,...",./102flowers/image_06736.jpg
6737,"[[9, 2, 27, 4, 1, 6, 14, 7, 12, 19, 5427, 5427...",./102flowers/image_06737.jpg
6738,"[[9, 1, 5, 8, 54, 16, 38, 7, 12, 116, 325, 3, ...",./102flowers/image_06738.jpg
6739,"[[4, 12, 1, 5, 29, 11, 19, 7, 26, 70, 5427, 54...",./102flowers/image_06739.jpg


## Create the dataset with the `tf.data` API

In [8]:
# Generated images for this competition must be 64x64 RGB.
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL = 64, 64, 3

def training_data_generator(text_embeddings, image_path):
    """Load one training example for the ``tf.data`` pipeline.

    Args:
        text_embeddings: caption embedding paired with ``image_path``; it is cast
            to float32.
        image_path: scalar string tensor holding the image file path.

    Returns:
        ``(img, text_embeddings)``. ``img`` is an ``IMAGE_HEIGHT x IMAGE_WIDTH x
        IMAGE_CHANNEL`` float32 tensor scaled to ``[-1, 1]``, the same range as
        the generators' ``tanh`` output. ``text_embeddings`` is float32.
    """
    img = tf.io.read_file(image_path)
    # expand_animations=False guarantees a rank-3 tensor even for GIF files.
    img = tf.image.decode_image(img, channels=IMAGE_CHANNEL, expand_animations=False)
    # uint8 -> float32 in [0, 1]
    img = tf.image.convert_image_dtype(img, tf.float32)
    img.set_shape([None, None, IMAGE_CHANNEL])
    img = tf.image.resize(img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    img.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    # Map [0, 1] -> [-1, 1]. The generators end in tanh, and saved samples are
    # de-normalized with x * 0.5 + 0.5 / 127.5 * x + 127.5. Real images must use
    # that same range, or the discriminator can tell real from fake by value range.
    img = img * 2.0 - 1.0
    text_embeddings = tf.cast(text_embeddings, tf.float32)
    return img, text_embeddings


def dataset_generator(filenames, batch_size, data_generator):
    """Build a shuffled, batched ``tf.data`` pipeline from a pickle of pairs.

    Args:
        filenames: path to a pickle readable by ``pd.read_pickle``. Its entries
            are indexed as ``df[i][0]`` (image path) and ``df[i][1]`` (text
            embedding; presumably a (512,) vector -- confirm against the file).
        batch_size: examples per batch; the last partial batch is dropped.
        data_generator: function mapping ``(text_embedding, image_path)`` to a
            dataset element, e.g. ``training_data_generator``.

    Returns:
        A prefetched ``tf.data.Dataset`` of ``data_generator`` outputs.

    Raises:
        ValueError: if the pickle contains no pairs.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    df = pd.read_pickle(filenames)
    pairs = [df[i] for i in range(len(df))]
    if not pairs:
        raise ValueError('no (image path, embedding) pairs found in %s' % filenames)

    # Two parallel arrays with one entry per pair (no expansion takes place).
    all_image_paths = np.asarray([pair[0] for pair in pairs])
    all_text_embeddings = np.asarray([pair[1] for pair in pairs])

    dataset = tf.data.Dataset.from_tensor_slices((all_text_embeddings, all_image_paths))
    # Shuffle the cheap (embedding, path) records *before* decoding. Shuffling
    # after `map` keeps a full buffer of decoded images in memory, which here is
    # the whole dataset because the buffer size equals the number of examples.
    dataset = dataset.shuffle(len(all_image_paths))
    dataset = dataset.map(data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

In [10]:
BATCH_SIZE = 32
# Stage-I training data: one (image path, caption embedding) pair per entry.
dataset = dataset_generator(
    './2024-datalab-cup3-reverse-image-caption/dataset/image_text_pairs.pkl',
    batch_size=BATCH_SIZE,
    data_generator=training_data_generator,
)

In [None]:
# Pull one batch to sanity-check the pipeline's tensor shapes.
data = next(iter(dataset))
image_shape, caption_shape = data[0].shape, data[1].shape
print('Image batch shape:', image_shape)
print('Caption batch shape:', caption_shape)

## StackGAN

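### KL divergence

Regularizer for conditioning augmentation. For each latent dimension, $D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = -\log\sigma + \tfrac{1}{2}\left(\sigma^2 + \mu^2 - 1\right)$, averaged over the batch and all dimensions.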

In [10]:
from keras import backend as K


def KL_loss(y_true, y_pred):
    """KL divergence between N(mean, sigma^2) and N(0, 1), averaged over every entry.

    ``y_pred`` packs ``[mean | log(sigma)]`` along its second axis: columns
    ``[:128]`` hold the mean and the remaining columns hold log-sigma. ``y_true``
    is ignored; it exists only so the function matches the Keras
    ``loss(y_true, y_pred)`` signature.
    """
    mean, logsigma = y_pred[:, :128], y_pred[:, 128:]
    # Per-dimension closed form: -log(sigma) + 0.5 * (sigma^2 + mean^2 - 1).
    kl_per_entry = -logsigma + 0.5*(-1 + K.exp(2.0*logsigma) + K.square(mean))
    return K.mean(kl_per_entry)

### Stage 1 GAN

In [11]:
class ConditioningAugmentation(tf.keras.Model):
  """Conditioning-augmentation network.

  It projects a text embedding to 256 values ``phi = [mean | log_sigma]``
  (128 each). It then samples the conditioning vector
  ``C = mean + eps * exp(log_sigma)`` with the reparameterization trick.
  """

  def __init__(self):
    super(ConditioningAugmentation, self).__init__()
    self.dense = tf.keras.layers.Dense(units = 256)

  def call(self, E):
    """Return ``(C, phi)``.

    ``C`` is the sampled ``(batch, 128)`` conditioning vector. ``phi`` is the
    ``(batch, 256)`` tensor of statistics consumed by ``KL_loss``.
    """
    X = self.dense(E)
    phi = tf.nn.leaky_relu(X)
    mean = phi[:, :128]
    std = tf.exp(phi[:, 128:])
    # Draw independent standard-normal noise per example and per dimension.
    # A single (128,) draw broadcast across the batch would give every caption
    # in the batch the same perturbation.
    epsilon = tf.random.normal(shape = tf.shape(mean))
    C = mean + epsilon*std
    return C, phi

In [12]:
class EmbeddingCompressor(tf.keras.Model):
  """Compress a text embedding to 128 features: Dense(128) followed by ReLU."""

  def __init__(self):
    super().__init__()
    self.dense = tf.keras.layers.Dense(units = 128)

  def call(self, E):
    return tf.nn.relu(self.dense(E))

In [None]:
class Stage1Generator(tf.keras.Model):
    """Stage-I generator: text embedding + noise -> 64x64x3 image in [-1, 1].

    ``call`` takes ``[E, Z]`` (text embeddings, noise). It returns
    ``(image, phi)``, where ``phi`` holds the conditioning-augmentation
    statistics for ``KL_loss``.
    """

    def __init__(self):
        super(Stage1Generator, self).__init__()

        def init():
            # A fresh N(0, 0.02) initializer for each layer.
            return tf.random_normal_initializer(stddev=0.02)

        def norm():
            return tf.keras.layers.BatchNormalization(axis=-1, momentum=0.99)

        def upconv(filters):
            # Stride-2 transposed convolution: doubles height and width.
            return tf.keras.layers.Conv2DTranspose(filters=filters, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=init())

        self.canet = ConditioningAugmentation()
        self.concat = tf.keras.layers.Concatenate(axis=1)
        self.dense = tf.keras.layers.Dense(units=128 * 8 * 4 * 4, kernel_initializer=init())
        self.reshape = tf.keras.layers.Reshape(target_shape=(4, 4, 128 * 8), input_shape=(128 * 8 * 4 * 4,))
        self.batchnorm1 = norm()
        self.deconv1 = upconv(512)
        self.batchnorm2 = norm()
        self.deconv2 = upconv(256)
        self.batchnorm3 = norm()
        self.deconv3 = upconv(128)
        self.batchnorm4 = norm()
        self.deconv4 = upconv(3)

    def call(self, inputs):
        E, Z = inputs
        C, phi = self.canet(E)

        # [C | Z] -> 4x4x1024 seed feature map.
        X = self.reshape(self.dense(self.concat([C, Z])))
        X = tf.nn.relu(self.batchnorm1(X))

        # 4 -> 8 -> 16 -> 32, each stage followed by BN + ReLU.
        stages = ((self.deconv1, self.batchnorm2),
                  (self.deconv2, self.batchnorm3),
                  (self.deconv3, self.batchnorm4))
        for deconv, batchnorm in stages:
            X = tf.nn.relu(batchnorm(deconv(X)))

        # 32 -> 64 with 3 channels, squashed into [-1, 1].
        return tf.nn.tanh(self.deconv4(X)), phi


In [14]:
class Stage1Discriminator(tf.keras.Model):
  """Stage-I conditional discriminator.

  ``call`` takes ``[I, E]`` (images, text embeddings) and returns one logit per
  example, with shape ``(batch,)``. The text tiling below assumes 64x64 inputs,
  which four stride-2 convolutions reduce to 4x4.
  """

  def __init__(self):
    super(Stage1Discriminator, self).__init__()
    self.conv1 = tf.keras.layers.Conv2D(filters = 64, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.conv2 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv3 = tf.keras.layers.Conv2D(filters = 256, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv4 = tf.keras.layers.Conv2D(filters = 512, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm3 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.embed = EmbeddingCompressor()
    self.reshape = tf.keras.layers.Reshape(target_shape = (1, 1, 128))
    self.concat = tf.keras.layers.Concatenate()
    self.conv5 = tf.keras.layers.Conv2D(filters = 64*8, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm4 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv6 = tf.keras.layers.Conv2D(filters = 1, kernel_size = 4, strides = 1, padding = "valid", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))

  def call(self, inputs):
    I, E = inputs
    X = self.conv1(I)
    X = tf.nn.leaky_relu(X)

    X = self.conv2(X)
    X = self.batchnorm1(X)
    X = tf.nn.leaky_relu(X)

    X = self.conv3(X)
    X = self.batchnorm2(X)
    X = tf.nn.leaky_relu(X)

    X = self.conv4(X)
    X = self.batchnorm3(X)
    X = tf.nn.leaky_relu(X)

    # Compress the text embedding and replicate it over the 4x4 feature grid.
    T = self.embed(E)
    T = self.reshape(T)
    T = tf.tile(T, (1, 4, 4, 1))
    merged_input = self.concat([X, T])

    Y = self.conv5(merged_input)
    Y = self.batchnorm4(Y)
    Y = tf.nn.leaky_relu(Y)

    # The 4x4 "valid" conv gives (batch, 1, 1, 1). Squeeze only the singleton
    # spatial/channel axes so a batch of one still yields shape (1,), not a scalar.
    Y = self.conv6(Y)
    return tf.squeeze(Y, axis = [1, 2, 3])

In [17]:
class Stage1Model(tf.keras.Model):
  """Holds the Stage-I generator/discriminator pair and their training loop."""

  def __init__(self):
    super(Stage1Model, self).__init__()
    self.stage1_generator = Stage1Generator()
    self.stage1_discriminator = Stage1Discriminator()

  def train(self, train_ds, batch_size = 32, num_epochs = 600, z_dim = 100, c_dim = 128, stage1_generator_lr = 0.0004, stage1_discriminator_lr = 0.0004, steps_per_epoch = 125):
    """Adversarially train Stage I, writing previews and checkpoints every 10 epochs.

    Args:
      train_ds: dataset yielding ``(image_batch, embedding_batch)`` pairs.
      batch_size: kept for backward compatibility. Noise and labels are sized
        from each incoming batch, so they can never disagree with the data.
      num_epochs: number of epochs to run.
      z_dim: dimensionality of the generator's noise input.
      c_dim: unused; kept for backward compatibility.
      stage1_generator_lr: initial Adam learning rate of the generator.
      stage1_discriminator_lr: initial Adam learning rate of the discriminator.
        Both rates are halved every 100 epochs.
      steps_per_epoch: maximum number of batches drawn per epoch.
    """
    # `learning_rate` is the supported keyword. The legacy `lr` alias is dropped
    # with only a warning by TF >= 2.11 Keras optimizers (leaving the 0.001
    # default in effect) and is rejected outright by Keras 3.
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_generator_lr, beta_1 = 0.5, beta_2 = 0.999)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_discriminator_lr, beta_1 = 0.5, beta_2 = 0.999)

    for epoch in range(num_epochs):
      print("Epoch %d/%d:\n ["%(epoch + 1, num_epochs), end = "")
      start_time = time.time()
      # Halve both rates every 100 epochs. The `epoch > 0` guard keeps the
      # configured rate for the first 100 epochs instead of halving it at once.
      if epoch > 0 and epoch % 100 == 0:
        for optimizer in (generator_optimizer, discriminator_optimizer):
          K.set_value(optimizer.learning_rate, optimizer.learning_rate / 2)

      generator_loss_log = []
      discriminator_loss_log = []
      # take() caps the epoch at steps_per_epoch batches. If the dataset is
      # shorter, the epoch just ends, rather than next() raising StopIteration.
      for i, (image_batch, embedding_batch) in enumerate(train_ds.take(steps_per_epoch)):
        if i % 5 == 0:
          print("=", end = "")
        generator_loss, discriminator_loss = self._stage1_step(
            image_batch, embedding_batch, z_dim, generator_optimizer, discriminator_optimizer)
        generator_loss_log.append(generator_loss)
        discriminator_loss_log.append(discriminator_loss)

      epoch_time = time.time() - start_time
      template = "] - generator_loss: {:.4f} - discriminator_loss: {:.4f} - epoch_time: {:.2f} s"
      print(template.format(tf.reduce_mean(generator_loss_log), tf.reduce_mean(discriminator_loss_log), epoch_time))

      if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
        self._save_stage1_snapshot(train_ds, epoch + 1, z_dim)

  def _stage1_step(self, image_batch, embedding_batch, z_dim, generator_optimizer, discriminator_optimizer):
    """Run one generator + discriminator update and return both losses."""
    n = tf.shape(image_batch)[0]
    z_noise = tf.random.normal((n, z_dim))
    # Wrong-pair examples: each caption is paired with the previous image in the batch.
    mismatched_images = tf.roll(image_batch, shift = 1, axis = 0)
    # Soft, noisy targets: real in [0.9, 1.0), fake and mismatched in [0.0, 0.1).
    real_labels = tf.random.uniform(shape = (n, ), minval = 0.9, maxval = 1.0)
    fake_labels = tf.random.uniform(shape = (n, ), minval = 0.0, maxval = 0.1)
    mismatched_labels = tf.random.uniform(shape = (n, ), minval = 0.0, maxval = 0.1)

    with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
      # training=True propagates to the nested BatchNormalization layers, which
      # then use batch statistics and update their moving averages. Without it,
      # Keras runs them in inference mode during training.
      fake_images, phi = self.stage1_generator([embedding_batch, z_noise], training = True)
      real_logits = self.stage1_discriminator([image_batch, embedding_batch], training = True)
      fake_logits = self.stage1_discriminator([fake_images, embedding_batch], training = True)
      mismatched_logits = self.stage1_discriminator([mismatched_images, embedding_batch], training = True)

      # Generator: fool D on (fake, caption), plus the conditioning-augmentation KL term.
      l_sup = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels, fake_logits))
      l_klreg = KL_loss(None, phi)  # KL_loss does not read y_true
      generator_loss = l_sup + 2.0*l_klreg

      # Discriminator: real pairs vs. (fake, caption) and (wrong image, caption).
      l_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels, real_logits))
      l_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fake_labels, fake_logits))
      l_mismatched = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(mismatched_labels, mismatched_logits))
      discriminator_loss = 0.5*tf.add(l_real, 0.5*tf.add(l_fake, l_mismatched))

    generator_gradients = generator_tape.gradient(generator_loss, self.stage1_generator.trainable_variables)
    discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.stage1_discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, self.stage1_generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.stage1_discriminator.trainable_variables))
    return generator_loss, discriminator_loss

  def _save_stage1_snapshot(self, train_ds, epoch_number, z_dim):
    """Write up to 10 preview images and the current weights for `epoch_number`."""
    save_path = "./Text to Image/lr_results/epoch_" + str(epoch_number)
    os.makedirs(save_path, exist_ok = True)
    temp_embeddings = next(iter(train_ds))[1].numpy()
    temp_batch_size = 10
    temp_embedding_batch = temp_embeddings[0:temp_batch_size]
    # Match the noise to however many embeddings were actually available.
    temp_z_noise = tf.random.normal((len(temp_embedding_batch), z_dim))
    fake_images, _ = self.stage1_generator([temp_embedding_batch, temp_z_noise], training = False)
    for i, image in enumerate(fake_images):
      image = (127.5*image + 127.5).numpy().astype('uint8')
      # The generator produces RGB (the training images are decoded as RGB),
      # but cv2.imwrite expects BGR.
      cv2.imwrite(save_path + "/gen_%d.png"%(i), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    checkpoint_dir = "./Text to Image/lr_model_checkpoints"
    os.makedirs(checkpoint_dir, exist_ok = True)
    self.stage1_generator.save_weights(checkpoint_dir + "/stage1_generator_" + str(epoch_number) + ".ckpt")
    self.stage1_discriminator.save_weights(checkpoint_dir + "/stage1_discriminator_" + str(epoch_number) + ".ckpt")

  def generate_image(self, embedding, z_dim = 100):
    """Generate Stage-I images for `embedding` using the epoch-600 generator weights.

    Args:
      embedding: batch of text embeddings, one row per image to generate.
      z_dim: dimensionality of the noise input; it must match training.

    Returns:
      The generated image batch in [-1, 1], with one image per embedding row.
    """
    self.stage1_generator.load_weights("./Text to Image/lr_model_checkpoints/stage1_generator_600.ckpt").expect_partial()
    z_noise = tf.random.normal((tf.shape(embedding)[0], z_dim))
    generated_image, _ = self.stage1_generator([embedding, z_noise], training = False)
    return generated_image

In [None]:
# Train Stage I from freshly initialized networks on the 64x64 dataset above.
model = Stage1Model()
model.train(train_ds=dataset)

Epoch 1/600:
 [=Concatenated input shape: (32, 228)
Dense layer output shape: (32, 16384)
Reshaped output shape: (32, 4, 4, 1024)
After deconv1: (32, 8, 8, 512)
After deconv2: (32, 16, 16, 256)
After deconv3: (32, 32, 32, 128)
Concatenated input shape: (32, 228)
Dense layer output shape: (32, 16384)
Reshaped output shape: (32, 4, 4, 1024)
After deconv1: (32, 8, 8, 512)
After deconv2: (32, 16, 16, 256)
After deconv3: (32, 32, 32, 128)
Concatenated input shape: (32, 228)
Dense layer output shape: (32, 16384)
Reshaped output shape: (32, 4, 4, 1024)
After deconv1: (32, 8, 8, 512)
After deconv2: (32, 16, 16, 256)
After deconv3: (32, 32, 32, 128)
Concatenated input shape: (32, 228)
Dense layer output shape: (32, 16384)
Reshaped output shape: (32, 4, 4, 1024)
After deconv1: (32, 8, 8, 512)
After deconv2: (32, 16, 16, 256)
After deconv3: (32, 32, 32, 128)
Concatenated input shape: (32, 228)
Dense layer output shape: (32, 16384)
Reshaped output shape: (32, 4, 4, 1024)
After deconv1: (32, 8, 8, 

### Stage 2 GAN

In [None]:
class ResidualBlock(tf.keras.layers.Layer):
  """Residual block used by the Stage-II generator.

  Computes ``relu(F(I) + I)``, where F is conv3x3 -> BN -> ReLU -> conv3x3 ->
  BN -> ReLU with 128*4 = 512 filters. The input must therefore have 512
  channels for the skip addition to line up.
  """

  def __init__(self):
    super(ResidualBlock, self).__init__()
    self.conv1 = tf.keras.layers.Conv2D(filters = 128*4, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv2 = tf.keras.layers.Conv2D(filters = 128*4, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    # Stateless merge layer, created once here instead of on every forward pass.
    self.skip_add = tf.keras.layers.Add()

  # `call` must sit at class level so Keras uses it as the forward pass. A
  # function nested inside __init__ is just a local that Keras never sees.
  def call(self, I):
    X = self.conv1(I)
    X = self.batchnorm1(X)
    X = tf.nn.relu(X)

    X = self.conv2(X)
    X = self.batchnorm2(X)
    # NOTE(review): standard ResNet/StackGAN blocks apply no ReLU before the
    # skip addition; kept to preserve the original architecture -- confirm.
    X = tf.nn.relu(X)
    X = self.skip_add([X, I])
    return tf.nn.relu(X)

In [None]:
class Stage2Generator(tf.keras.Model):
  """Stage-II generator: refines a Stage-I image, conditioned on the text embedding."""

  def __init__(self):
    super(Stage2Generator, self).__init__()
    self.canet = ConditioningAugmentation()
    self.conv1 = tf.keras.layers.Conv2D(128, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.conv2 = tf.keras.layers.Conv2D(256, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv3 = tf.keras.layers.Conv2D(512, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv4 = tf.keras.layers.Conv2D(512, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm3 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.resblock1 = ResidualBlock()
    self.resblock2 = ResidualBlock()
    self.resblock3 = ResidualBlock()
    self.resblock4 = ResidualBlock()
    self.upsamp1 = tf.keras.layers.UpSampling2D(size = (2, 2))
    self.conv5 = tf.keras.layers.Conv2D(256, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm4 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.upsamp2 = tf.keras.layers.UpSampling2D(size = (2, 2))
    self.conv6 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm5 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.upsamp3 = tf.keras.layers.UpSampling2D(size = (2, 2))
    self.conv7 = tf.keras.layers.Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm6 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.upsamp4 = tf.keras.layers.UpSampling2D(size = (2, 2))
    self.conv8 = tf.keras.layers.Conv2D(filters = 32, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm7 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv9 = tf.keras.layers.Conv2D(filters = 3, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))

  def call(self, inputs):
    """Refine ``[E, I]`` (text embeddings, Stage-I images) into ``(image, phi)``.

    For a 64x64 Stage-I input the encoder reaches 16x16, and the four 2x
    upsamplings produce a 256x256x3 tanh image in [-1, 1]. ``phi`` holds the
    conditioning-augmentation statistics for ``KL_loss``.
    """
    E, I = inputs
    C, phi = self.canet(E)

    X = self.conv1(I)
    X = tf.nn.relu(X)

    X = self.conv2(X)
    X = self.batchnorm1(X)
    X = tf.nn.relu(X)

    X = self.conv3(X)
    X = self.batchnorm2(X)
    X = tf.nn.relu(X)

    # Broadcast the conditioning vector over the encoder's spatial grid and join
    # it to the image features along the channel axis.
    C = C[:, tf.newaxis, tf.newaxis, :]
    C = tf.tile(C, [1, tf.shape(X)[1], tf.shape(X)[2], 1])
    J = tf.concat([C, X], axis = 3)

    # conv4 consumes the text-conditioned features J. Feeding it X alone would
    # discard the conditioning entirely.
    X = self.conv4(J)
    X = self.batchnorm3(X)
    X = tf.nn.relu(X)

    X = self.resblock1(X)
    X = self.resblock2(X)
    X = self.resblock3(X)
    X = self.resblock4(X)

    X = self.upsamp1(X)
    X = self.conv5(X)
    X = self.batchnorm4(X)
    X = tf.nn.relu(X)

    X = self.upsamp2(X)
    X = self.conv6(X)
    X = self.batchnorm5(X)
    X = tf.nn.relu(X)

    X = self.upsamp3(X)
    X = self.conv7(X)
    X = self.batchnorm6(X)
    X = tf.nn.relu(X)

    X = self.upsamp4(X)
    X = self.conv8(X)
    X = self.batchnorm7(X)
    X = tf.nn.relu(X)

    X = self.conv9(X)
    return tf.nn.tanh(X), phi
     

In [None]:
class Stage2Discriminator(tf.keras.Model):
  """Stage-II conditional discriminator.

  ``call`` takes ``[I, E]`` (images, text embeddings) and returns one logit per
  example, with shape ``(batch,)``. The text tiling assumes 256x256 inputs,
  which six stride-2 convolutions reduce to 4x4.
  """

  def __init__(self):
    super(Stage2Discriminator, self).__init__()
    self.embed = EmbeddingCompressor()
    self.reshape = tf.keras.layers.Reshape(target_shape = (1, 1, 128))
    self.conv1 = tf.keras.layers.Conv2D(filters = 64, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.conv2 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv3 = tf.keras.layers.Conv2D(filters = 256, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv4 = tf.keras.layers.Conv2D(filters = 512, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm3 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv5 = tf.keras.layers.Conv2D(filters = 1024, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm4 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv6 = tf.keras.layers.Conv2D(filters = 2048, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm5 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv7 = tf.keras.layers.Conv2D(filters = 1024, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm6 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv8 = tf.keras.layers.Conv2D(filters = 512, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm7 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv9 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm8 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv10 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm9 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv11 = tf.keras.layers.Conv2D(filters = 512, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm10 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv12 = tf.keras.layers.Conv2D(filters = 64*8, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm11 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv13 = tf.keras.layers.Conv2D(filters = 1, kernel_size = 4, strides = 1, padding = "valid", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))

  def call(self, inputs):
    I, E = inputs
    # Compressed text vector replicated over the final 4x4 feature grid.
    T = self.embed(E)
    T = self.reshape(T)
    T = tf.tile(T, (1, 4, 4, 1))

    X = self.conv1(I)
    X = tf.nn.leaky_relu(X)

    X = self.conv2(X)
    X = self.batchnorm1(X)
    X = tf.nn.leaky_relu(X)

    X = self.conv3(X)
    X = self.batchnorm2(X)
    X = tf.nn.leaky_relu(X)

    X = self.conv4(X)
    X = self.batchnorm3(X)
    X = tf.nn.leaky_relu(X)

    X = self.conv5(X)
    X = self.batchnorm4(X)
    X = tf.nn.leaky_relu(X)

    X = self.conv6(X)
    X = self.batchnorm5(X)
    X = tf.nn.leaky_relu(X)

    X = self.conv7(X)
    X = self.batchnorm6(X)
    X = tf.nn.leaky_relu(X)

    # Residual branch: X (no activation) is added to the output of Y's 3-conv path.
    X = self.conv8(X)
    X = self.batchnorm7(X)

    Y = self.conv9(X)
    Y = self.batchnorm8(Y)
    Y = tf.nn.leaky_relu(Y)

    Y = self.conv10(Y)
    Y = self.batchnorm9(Y)
    Y = tf.nn.leaky_relu(Y)

    Y = self.conv11(Y)
    Y = self.batchnorm10(Y)

    # Plain tensor ops instead of instantiating new Keras layers on every call.
    A = tf.nn.leaky_relu(X + Y)
    merged_input = tf.concat([A, T], axis = -1)

    Z = self.conv12(merged_input)
    Z = self.batchnorm11(Z)
    Z = tf.nn.leaky_relu(Z)

    # (batch, 1, 1, 1) -> (batch,). An explicit axis keeps a batch of one as shape (1,).
    Z = self.conv13(Z)
    return tf.squeeze(Z, axis = [1, 2, 3])

In [None]:
class Stage2Model(tf.keras.Model):
  """Trains the Stage-II generator/discriminator on top of a frozen Stage-I generator."""

  def __init__(self, resume_epoch = 170):
    """Build the networks and load pretrained weights.

    Args:
      resume_epoch: epoch number of the Stage-II checkpoints to resume from
        (``hr_model_checkpoints/stage2_*_<resume_epoch>.ckpt``). Each one is
        restored only if its files exist; pass ``None`` to always start fresh.
        The epoch-600 Stage-I generator weights are always required.
    """
    super(Stage2Model, self).__init__()
    # Stage I is used for inference only; its trained weights are mandatory.
    self.stage1_generator = Stage1Generator()
    self.stage1_generator.load_weights("./Text to Image/lr_model_checkpoints/stage1_generator_600.ckpt").expect_partial()

    self.stage2_generator = Stage2Generator()
    self.stage2_discriminator = Stage2Discriminator()
    if resume_epoch is not None:
      for name, network in (("stage2_generator", self.stage2_generator), ("stage2_discriminator", self.stage2_discriminator)):
        prefix = "./Text to Image/hr_model_checkpoints/" + name + "_" + str(resume_epoch) + ".ckpt"
        # TensorFlow-format checkpoints are stored as <prefix>.index plus data shards.
        if os.path.exists(prefix + ".index"):
          network.load_weights(prefix).expect_partial()
        else:
          print("No checkpoint at %s; %s starts from scratch." % (prefix, name))

  def train(self, train_ds, batch_size = 64, num_epochs = 1200, z_dim = 100, stage1_generator_lr = 0.0001, stage1_discriminator_lr = 0.0001, steps_per_epoch = 125):
    """Adversarially train Stage II, writing previews and checkpoints every 10 epochs.

    Args:
      train_ds: dataset yielding ``(hr_image_batch, embedding_batch)`` pairs.
      batch_size: kept for backward compatibility. Noise and labels are sized
        from each incoming batch; the old default of 64 disagreed with the
        32-example batches and broke the generator's concatenation.
      num_epochs: number of epochs to run.
      z_dim: dimensionality of the Stage-I noise input.
      stage1_generator_lr: initial Adam learning rate of the Stage-II generator.
      stage1_discriminator_lr: initial Adam learning rate of the Stage-II
        discriminator. Both rates are halved every 100 epochs.
      steps_per_epoch: maximum number of batches drawn per epoch.
    """
    # `learning_rate`, not the legacy `lr` alias, which recent Keras optimizers
    # ignore (TF >= 2.11) or reject (Keras 3).
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_generator_lr, beta_1 = 0.5, beta_2 = 0.999)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_discriminator_lr, beta_1 = 0.5, beta_2 = 0.999)

    for epoch in range(num_epochs):
      print("Epoch %d/%d:\n ["%(epoch + 1, num_epochs), end = "")
      start_time = time.time()
      # Halve both rates every 100 epochs, but not before the first 100 have run.
      if epoch > 0 and epoch % 100 == 0:
        for optimizer in (generator_optimizer, discriminator_optimizer):
          K.set_value(optimizer.learning_rate, optimizer.learning_rate / 2)

      generator_loss_log = []
      discriminator_loss_log = []
      # take() caps the epoch without raising StopIteration on a short dataset.
      for i, (hr_image_batch, embedding_batch) in enumerate(train_ds.take(steps_per_epoch)):
        if i % 5 == 0:
          print("=", end = "")
        generator_loss, discriminator_loss = self._stage2_step(
            hr_image_batch, embedding_batch, z_dim, generator_optimizer, discriminator_optimizer)
        generator_loss_log.append(generator_loss)
        discriminator_loss_log.append(discriminator_loss)

      epoch_time = time.time() - start_time
      template = "] - generator_loss: {:.4f} - discriminator_loss: {:.4f} - epoch_time: {:.2f} s"
      print(template.format(tf.reduce_mean(generator_loss_log), tf.reduce_mean(discriminator_loss_log), epoch_time))

      if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
        self._save_stage2_snapshot(train_ds, epoch + 1, z_dim)

  def _stage2_step(self, hr_image_batch, embedding_batch, z_dim, generator_optimizer, discriminator_optimizer):
    """Run one Stage-II generator + discriminator update and return both losses."""
    n = tf.shape(hr_image_batch)[0]
    z_noise = tf.random.normal((n, z_dim))
    # Wrong-pair examples: each caption is paired with the previous image in the batch.
    mismatched_images = tf.roll(hr_image_batch, shift = 1, axis = 0)
    real_labels = tf.random.uniform(shape = (n, ), minval = 0.9, maxval = 1.0)
    fake_labels = tf.random.uniform(shape = (n, ), minval = 0.0, maxval = 0.1)
    mismatched_labels = tf.random.uniform(shape = (n, ), minval = 0.0, maxval = 0.1)

    # Stage I is frozen, so its low-resolution images are produced outside the tapes.
    lr_fake_images, _ = self.stage1_generator([embedding_batch, z_noise], training = False)

    with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
      # training=True so nested BatchNormalization layers use batch statistics.
      hr_fake_images, phi = self.stage2_generator([embedding_batch, lr_fake_images], training = True)
      real_logits = self.stage2_discriminator([hr_image_batch, embedding_batch], training = True)
      fake_logits = self.stage2_discriminator([hr_fake_images, embedding_batch], training = True)
      mismatched_logits = self.stage2_discriminator([mismatched_images, embedding_batch], training = True)

      l_sup = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels, fake_logits))
      l_klreg = KL_loss(None, phi)  # KL_loss does not read y_true
      generator_loss = l_sup + 2.0*l_klreg

      l_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels, real_logits))
      l_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fake_labels, fake_logits))
      l_mismatched = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(mismatched_labels, mismatched_logits))
      discriminator_loss = 0.5*tf.add(l_real, 0.5*tf.add(l_fake, l_mismatched))

    generator_gradients = generator_tape.gradient(generator_loss, self.stage2_generator.trainable_variables)
    discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.stage2_discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, self.stage2_generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.stage2_discriminator.trainable_variables))
    return generator_loss, discriminator_loss

  def _save_stage2_snapshot(self, train_ds, epoch_number, z_dim):
    """Write up to 10 high-resolution preview images and the current Stage-II weights."""
    save_path = "./Text to Image/hr_results/epoch_" + str(epoch_number)
    os.makedirs(save_path, exist_ok = True)
    temp_embeddings = next(iter(train_ds))[1].numpy()
    temp_batch_size = 10
    temp_embedding_batch = temp_embeddings[0:temp_batch_size]
    temp_z_noise = tf.random.normal((len(temp_embedding_batch), z_dim))
    # Preview the Stage-II output (refined from Stage I), not the Stage-I images.
    lr_images, _ = self.stage1_generator([temp_embedding_batch, temp_z_noise], training = False)
    fake_images, _ = self.stage2_generator([temp_embedding_batch, lr_images], training = False)
    for i, image in enumerate(fake_images):
      image = (127.5*image + 127.5).numpy().astype('uint8')
      # The generator produces RGB; cv2.imwrite expects BGR.
      cv2.imwrite(save_path + "/gen_%d.png"%(i), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    checkpoint_dir = "./Text to Image/hr_model_checkpoints"
    os.makedirs(checkpoint_dir, exist_ok = True)
    self.stage2_generator.save_weights(checkpoint_dir + "/stage2_generator_" + str(epoch_number) + ".ckpt")
    self.stage2_discriminator.save_weights(checkpoint_dir + "/stage2_discriminator_" + str(epoch_number) + ".ckpt")

In [None]:
# Stage II trains on 256x256 images. training_data_generator reads these globals
# when dataset_generator maps it, so set them before building the dataset.
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
# Use the same (image path, text embedding) pairs as Stage I. text2ImgData.pkl is
# a DataFrame of token-id captions (see df.head() above), which dataset_generator
# cannot index as df[i][0] / df[i][1].
dataset = dataset_generator(data_path + '/image_text_pairs.pkl', BATCH_SIZE, training_data_generator)
stage2_model = Stage2Model()
stage2_model.train(dataset)

## Visualization

In [62]:
def merge(images, size):
    """Tile a batch of images onto one rows x cols canvas, filling row by row.

    Args:
        images: indexable batch whose ``shape[1]``/``shape[2]`` are the tile height/width;
            each tile must broadcast into a 3-channel slot.
        size: ``(rows, cols)`` of the grid.

    Returns:
        float64 array of shape ``(rows * h, cols * w, 3)``; unused cells stay zero.
    """
    n_rows, n_cols = size[0], size[1]
    tile_h, tile_w = images.shape[1], images.shape[2]
    canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, 3))
    for k, tile in enumerate(images):
        row, col = divmod(k, n_cols)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile
    return canvas

def imsave(images, size, path):
    """Save ``images`` as one ``size[0] x size[1]`` grid image at ``path``.

    Pixel values are mapped with ``x * 0.5 + 0.5`` (presumably generator output in
    [-1, 1] -> [0, 1]) and then clipped to [0, 1]. Without the clip, ``plt.imsave``
    raises ``ValueError`` for float RGB data outside that range.

    Returns whatever ``plt.imsave`` returns.
    """
    grid = merge(images, size) * 0.5 + 0.5
    return plt.imsave(path, np.clip(grid, 0.0, 1.0))

def save_images(images, size, image_path):
    """Write ``images`` as a single ``size[0] x size[1]`` grid to ``image_path``.

    Thin alias of ``imsave``; its return value is passed straight through.
    """
    saved = imsave(images, size, path=image_path)
    return saved

In [63]:
def sample_generator(caption, batch_size):
    """Wrap a list of token-id sequences in a batched ``tf.data.Dataset``.

    Args:
        caption: sequence of equal-length id lists (e.g. ``sent2IdList`` outputs).
        batch_size: number of sequences per batch.

    Returns:
        An un-shuffled, non-repeating dataset of integer id batches.
    """
    # Builtin `int` is exactly what the `np.int` alias pointed to (deprecated in
    # NumPy 1.20, removed in 1.24), so the resulting dtype is unchanged.
    caption = np.asarray(caption).astype(int)
    dataset = tf.data.Dataset.from_tensor_slices(caption)
    return dataset.batch(batch_size)

In [64]:
# Side length of the square preview grid, large enough to hold one batch.
ni = int(np.ceil(np.sqrt(hparas['BATCH_SIZE'])))
sample_size = hparas['BATCH_SIZE']
# Noise vectors drawn once here and reused for every preview render.
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)

# Each preview caption is repeated int(sample_size / ni) times, in this order.
preview_captions = [
    "the flower shown has yellow anther red pistil and bright red petals.",
    "this flower has petals that are yellow, white and purple and has dark lines",
    "the petals on this flower are white with a yellow center",
    "this flower has a lot of small round pink petals.",
    "this flower is orange in color, and has petals that are ruffled and rounded.",
    "the flower has yellow petals and the center of it is brown.",
    "this flower has petals that are blue and white.",
    "these white flowers have petals that start off white in color and end in a white towards the tips.",
]
repeats_per_caption = int(sample_size / ni)

sample_sentence = [
    sent2IdList(text)
    for text in preview_captions
    for _ in range(repeats_per_caption)
]
sample_sentence = sample_generator(sample_sentence, hparas['BATCH_SIZE'])

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  This is separate from the ipykernel package so we can avoid doing imports until


## Training

In [65]:
# Output folder for the training-time preview grids; exist_ok avoids the check-then-create race.
os.makedirs('samples/demo', exist_ok=True)

In [66]:
def train(dataset, epochs):
    """Train the text-to-image GAN, checkpointing and rendering previews periodically.

    Args:
        dataset: iterable of ``(image, caption)`` batches, iterated once per epoch.
        epochs: number of passes over ``dataset`` (previously ignored in favour of
            ``hparas['N_EPOCH']``, which is what callers pass anyway).

    Relies on notebook globals: ``text_encoder``, ``train_step``, ``test_step``,
    ``checkpoint``, ``checkpoint_prefix``, ``hparas``, ``sample_sentence``,
    ``sample_seed``, ``ni`` and ``save_images``.
    """
    # Initial RNN hidden state of the text encoder, passed to every train/test step.
    hidden = text_encoder.initialize_hidden_state()

    for epoch in range(epochs):
        g_total_loss = 0.0
        d_total_loss = 0.0
        n_steps = 0
        start = time.time()

        for image, caption in dataset:
            g_loss, d_loss = train_step(image, caption, hidden)
            g_total_loss += g_loss
            d_total_loss += d_loss
            n_steps += 1

        # Average over the batches actually seen; N_SAMPLE / BATCH_SIZE need not match the
        # real batch count. max() keeps an empty dataset from dividing by zero.
        n_steps = max(n_steps, 1)
        print("Epoch {}, gen_loss: {:.4f}, disc_loss: {:.4f}".format(epoch + 1,
                                                                     g_total_loss / n_steps,
                                                                     d_total_loss / n_steps))
        print('Time for epoch {} is {:.4f} sec'.format(epoch + 1, time.time() - start))

        # save the model
        if (epoch + 1) % 10 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        # visualization: render the fixed preview captions with the fixed noise
        if (epoch + 1) % hparas['PRINT_FREQ'] == 0:
            fake_image = None
            for preview_caption in sample_sentence:
                fake_image = test_step(preview_caption, sample_seed, hidden)
            if fake_image is not None:
                save_images(fake_image, [ni, ni], 'samples/demo/train_{:02d}.jpg'.format(epoch))

In [67]:
# Stage-I training run for the configured number of epochs.
train(dataset, epochs=hparas['N_EPOCH'])

Epoch 1, gen_loss: 0.4707, disc_loss: 1.0885
Time for epoch 1 is 18.1674 sec
Epoch 2, gen_loss: 0.5574, disc_loss: 0.9490
Time for epoch 2 is 16.9456 sec
Epoch 3, gen_loss: 0.7402, disc_loss: 0.7629
Time for epoch 3 is 17.3516 sec
Epoch 4, gen_loss: 1.3450, disc_loss: 0.3981
Time for epoch 4 is 18.2550 sec
Epoch 5, gen_loss: 2.0509, disc_loss: 0.1843
Time for epoch 5 is 17.4844 sec
Epoch 6, gen_loss: 1.8685, disc_loss: 0.2176
Time for epoch 6 is 16.4693 sec
Epoch 7, gen_loss: 2.4892, disc_loss: 0.1314
Time for epoch 7 is 17.3523 sec
Epoch 8, gen_loss: 2.6612, disc_loss: 0.1400
Time for epoch 8 is 15.7512 sec
Epoch 9, gen_loss: 2.9363, disc_loss: 0.1452
Time for epoch 9 is 16.0982 sec
Epoch 10, gen_loss: 3.5661, disc_loss: 0.1384
Time for epoch 10 is 15.2816 sec


## Evaluation

### Testing dataset

In [68]:
def testing_data_generator(caption, index):
    """Per-element map for the test set: cast caption ids to float32 and pass the id through."""
    return tf.cast(caption, tf.float32), index

def testing_dataset_generator(batch_size, data_generator,
                              data_path='./2024-datalab-cup3-reverse-image-caption/dataset/testData.pkl'):
    """Build an endlessly repeating, batched ``tf.data`` pipeline over the test captions.

    Args:
        batch_size: number of captions per batch.
        data_generator: per-element map ``(caption, id) -> (caption, id)``,
            e.g. ``testing_data_generator``.
        data_path: pickle holding ``'Captions'`` and ``'ID'`` columns; defaults to the
            competition test set used before this parameter existed.

    Returns:
        A ``tf.data.Dataset`` of ``(captions, ids)`` batches. Because of ``repeat()``
        it never ends, so callers must bound how many batches they consume.
    """
    # NOTE: pd.read_pickle can execute arbitrary code -- only load trusted files.
    data = pd.read_pickle(data_path)
    # Stack the per-row id lists into one int array (assumes every caption is padded to the
    # same length). Builtin `int` replaces the `np.int` alias (deprecated in NumPy 1.20,
    # removed in 1.24) and yields the same dtype.
    caption = np.asarray(list(data['Captions'].values)).astype(int)
    index = np.asarray(data['ID'].values)

    dataset = tf.data.Dataset.from_tensor_slices((caption, index))
    dataset = dataset.map(data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset.repeat().batch(batch_size)

In [69]:
# Batched (captions, ids) pipeline over the test set.
testing_dataset = testing_dataset_generator(batch_size=hparas['BATCH_SIZE'],
                                            data_generator=testing_data_generator)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == "":


In [70]:
data = pd.read_pickle('./2024-datalab-cup3-reverse-image-caption/dataset/testData.pkl')
captions = data['Captions'].values

# Number of test captions, and how many *full* batches of BATCH_SIZE they contain (floor).
NUM_TEST = len(captions)
EPOCH_TEST = NUM_TEST // hparas['BATCH_SIZE']

### Inference

In [71]:
# Output folder for per-caption inference images; exist_ok avoids the check-then-create race.
os.makedirs('./inference/demo', exist_ok=True)

In [72]:
def inference(dataset):
    """Generate one image per test caption and save it as ./inference/demo/inference_<id>.jpg.

    Args:
        dataset: batched ``(captions, ids)`` dataset such as ``testing_dataset``. That
            pipeline repeats forever, so exactly ceil(NUM_TEST / BATCH_SIZE) batches are
            consumed here -- enough to cover every test caption once. The previous
            ``step > EPOCH_TEST`` bound ran one redundant extra batch whenever NUM_TEST
            was a multiple of BATCH_SIZE.

    Relies on notebook globals: ``text_encoder``, ``test_step``, ``hparas``, ``NUM_TEST``.
    """
    hidden = text_encoder.initialize_hidden_state()
    batch_size = hparas['BATCH_SIZE']
    # One noise vector per batch slot, shared by every batch of this run (unseeded).
    noise = np.random.normal(loc=0.0, scale=1.0, size=(batch_size, hparas['Z_DIM'])).astype(np.float32)
    n_batches = int(np.ceil(NUM_TEST / batch_size))

    start = time.time()
    for captions, ids in dataset.take(n_batches):
        fake_image = test_step(captions, noise, hidden)
        # Same x*0.5+0.5 mapping as before, clipped to [0, 1]: plt.imsave raises
        # ValueError on float RGB values outside that range.
        images = np.clip(fake_image.numpy() * 0.5 + 0.5, 0.0, 1.0)
        for image, image_id in zip(images, ids.numpy()):
            plt.imsave('./inference/demo/inference_{:04d}.jpg'.format(int(image_id)), image)

    print('Time for inference is {:.4f} sec'.format(time.time() - start))

In [73]:
# Restore the newest checkpoint in `checkpoint_dir` rather than the hard-coded 'ckpt-1'.
# `train` saves a checkpoint every 10 epochs, so with longer runs 'ckpt-1' would silently
# evaluate stale weights. Set RESTORE_CKPT to a name such as 'ckpt-1' to pin one explicitly.
RESTORE_CKPT = None
if RESTORE_CKPT is None:
    ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
else:
    ckpt_path = checkpoint_dir + '/' + RESTORE_CKPT
if ckpt_path is None:
    # restore(None) would silently do nothing, so fail loudly instead.
    raise FileNotFoundError('no checkpoint found in {}'.format(checkpoint_dir))
checkpoint.restore(ckpt_path)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x21271eede08>

In [74]:
inference(testing_dataset)

Time for inference is 3.0621 sec


## Inception score

In [84]:
from numba import cuda

# Presumably frees this kernel's GPU memory before the inception-score script runs
# as a separate process.
# NOTE(review): resetting the device tears down the CUDA context TensorFlow is using;
# further TF GPU work in this kernel will likely fail until the kernel is restarted.
current_gpu = cuda.get_current_device()
current_gpu.reset()

In [85]:
import os

# Environment variables are inherited by `!python ...` child processes, so this quiets
# TensorFlow C++ logging in the scoring script (the already-loaded TF here is unaffected).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
score_file = './score.csv'
# Drop any stale result so an old score can't be reported if the new evaluation fails.
try:
    os.remove(score_file)
except FileNotFoundError:
    pass

In [87]:
# Run the provided scoring script from inside ./testing (its paths below are relative to it):
#   ../inference/demo/ -> folder of images written by `inference`
#   ../score.csv       -> CSV read back by the next cell
#   39                 -> NOTE(review): meaning defined by inception_score.py (presumably a batch size) -- confirm.
# NOTE(review): `!python` runs whichever `python` is first on PATH, which may not be this kernel's interpreter.
%cd testing
!python inception_score.py ../inference/demo/ ../score.csv 39
%cd ..

c:\Users\user\OneDrive - NTHU\桌面\DL\Comp3\testing
1 Physical GPUs, 1 Logical GPUs
['C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0023.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0041.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0057.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0058.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0059.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0068.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0072.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0078.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0081.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面\\DL\\Comp3\\inference\\demo\\inference_0089.jpg'
 'C:\\Users\\user\\OneDrive - NTHU\\桌面

In [88]:
import os
import pandas as pd
import numpy as np

# A missing score file means the scoring step produced no output; otherwise report
# the mean of its 'score' column.
if not os.path.exists(score_file):
    print('Evaluation Failed!')
else:
    df_score = pd.read_csv(score_file)
    mean_score = np.mean(df_score['score'].values)
    print('Mean Score: {:f}'.format(mean_score))

Mean Score: 1.035121
