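# StackGAN

Text-to-image synthesis in two stages: a Stage-I GAN generates 64x64 images from caption embeddings, and a Stage-II GAN refines them into 256x256 images.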

In [1]:
import os

# `!export` runs in a throw-away subshell, so the variable never reached this
# kernel's environment. Set it in-process instead, before TensorFlow is imported.
os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'

In [2]:
import os
import pickle
import time
import random
import PIL
import numpy as np
import pandas as pd
import tensorflow as tf

from PIL import Image
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt

2023-12-16 22:39:52.865648: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Index of the physical GPU to run on. The original run used the second GPU
# (index 1) of a 2-GPU machine; on machines with fewer GPUs fall back to the
# last available one instead of failing with an uncaught IndexError.
GPU_INDEX = 1

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        selected_gpu = gpus[min(GPU_INDEX, len(gpus) - 1)]
        # Restrict TensorFlow to the selected GPU only.
        tf.config.experimental.set_visible_devices(selected_gpu, 'GPU')

        # Memory growth must be configured identically on every physical GPU.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Visible devices and memory growth must be set before GPUs are initialized.
        print(e)

2 Physical GPUs, 1 Logical GPUs


2023-12-16 22:39:54.347727: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-12-16 22:39:54.347850: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-12-16 22:39:54.352072: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-12-16 22:39:54.352208: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-12-16 22:39:54.352317: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from S

# Preprocess Text

In [4]:
import re
import string

In [5]:
# Vocabulary produced by the text-preprocessing step. Ids are stored as strings
# (the sample output shows word2Id_dict['flower'] printed as 1, and sent2IdList
# returns quoted ids such as '9').
dictionary_path = './dictionary'
vocab = np.load(f'{dictionary_path}/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

# Each mapping file holds key/value rows that dict() turns into a lookup table.
word2Id_dict, id2word_dict = (
    dict(np.load(f'{dictionary_path}/{stem}.npy')) for stem in ('word2Id', 'id2Word')
)
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))

there are 5427 vocabularies in total
Word to id mapping, for example: flower -> 1
Id to word mapping, for example: 1 -> flower
Tokens: <PAD>: 5427; <RARE>: 5428


In [6]:
def sent2IdList(line, MAX_SEQ_LENGTH=20, word2id=None):
    """Convert a caption string into a fixed-length list of token ids.

    Every ASCII punctuation character is replaced by a space, and the text is
    split on whitespace. Each token is looked up in the vocabulary; unknown
    tokens map to the ``<RARE>`` id. The result is truncated or padded with the
    ``<PAD>`` id so it always has exactly ``MAX_SEQ_LENGTH`` entries.

    Args:
        line: Raw caption text.
        MAX_SEQ_LENGTH: Length of the returned id list.
        word2id: Optional word -> id mapping. When omitted, the module-level
            ``word2Id_dict`` loaded from the dictionary files is used.

    Returns:
        A list of exactly ``MAX_SEQ_LENGTH`` ids, of whatever type the mapping
        stores (the loaded dictionary stores strings).
    """
    vocab = word2Id_dict if word2id is None else word2id

    # string.punctuation already covers '-' and '.', so a single substitution
    # removes all punctuation. split() with no argument also collapses runs of
    # whitespace, so no empty tokens are produced.
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
    tokens = prep_line.split()

    # Truncate long captions, then pad short ones. Without truncation, longer
    # captions produced ragged id lists that could not be stacked into the
    # fixed-length (input_length=20) embedding input.
    tokens = tokens[:MAX_SEQ_LENGTH]
    tokens += ['<PAD>'] * (MAX_SEQ_LENGTH - len(tokens))

    return [vocab[tok] if tok in vocab else vocab['<RARE>'] for tok in tokens]

# Quick sanity check of the tokenizer on one sample caption.
text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
encoded_text = sent2IdList(text)
print(encoded_text)

the flower shown has yellow anther red pistil and bright red petals.
['9', '1', '82', '5', '11', '70', '20', '31', '3', '29', '20', '2', '5427', '5427', '5427', '5427', '5427', '5427', '5427', '5427']


In [7]:
# Shapes shared by the data pipeline and the models.
EMBEDDING_DIM = 1024
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_CHANNEL = 3
vocabulary_size = 0  # not read anywhere in the visible notebook

# Word-embedding table that turns token ids into EMBEDDING_DIM vectors. Its
# weights are never handed to an optimizer in this notebook, so the
# embeddings stay at their random initialization for the whole run.
caption_embedding_model = tf.keras.Sequential()
caption_embedding_model.add(
    tf.keras.layers.Embedding(
        input_dim=len(word2Id_dict),
        output_dim=EMBEDDING_DIM,
        input_length=20,
    )
)

def training_data_generator(image_path, caption):
    """``tf.data`` map function producing one (image, caption embedding) pair.

    The image at ``image_path`` is decoded as RGB and resized to
    ``IMAGE_HEIGHT x IMAGE_WIDTH``. Its pixels are rescaled to [-1, 1] so that
    real images share the value range of the generators' ``tanh`` output.
    ``caption`` (token ids, presumably shape (20,) — TODO confirm) is embedded
    with ``caption_embedding_model`` and averaged over the token axis.

    Returns:
        (image, caption_embedding): the image has shape
        (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL).
    """
    img = tf.io.read_file(image_path)
    # expand_animations=False decodes GIFs to a single 3-D frame instead of a
    # 4-D stack, so the rank-3 set_shape below always holds.
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)  # -> [0, 1]
    img.set_shape([None, None, 3])
    img = tf.image.resize(img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    img.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    # The generators end in tanh ([-1, 1]) and save_result maps [-1, 1] back to
    # [0, 255]. Leaving real images in [0, 1] would let the discriminator tell
    # real from fake purely by value range.
    img = img * 2.0 - 1.0

    caption = tf.cast(caption, tf.int32)
    caption_embedding = caption_embedding_model(caption)
    caption_embedding = tf.reduce_mean(caption_embedding, axis=0)

    return img, caption_embedding

def dataset_generator(filenames, batch_size, data_generator):
    """Build the shuffled, batched training ``tf.data.Dataset``.

    Args:
        filenames: path to a pickled DataFrame with a 'Captions' column (a list
            of token-id sequences per image) and an 'ImagePath' column.
        batch_size: batch size; incomplete final batches are dropped.
        data_generator: map function applied to each (image_path, caption_ids)
            pair, e.g. training_data_generator.

    Returns:
        A prefetched dataset of batches produced by ``data_generator``.
    """
    # NOTE: read_pickle unpickles arbitrary objects; only load trusted files.
    df = pd.read_pickle(filenames)
    # Each image has 1 to 10 captions; pick one at random. The choice is made
    # once, when the dataset is built, so every epoch over this dataset sees
    # the same caption for a given image.
    caption = np.asarray([random.choice(options) for options in df['Captions'].values])
    # np.int was an alias of the builtin int and was removed in NumPy 1.24.
    caption = caption.astype(int)
    image_path = df['ImagePath'].values

    # Each row of `caption` must correspond to the same row of `image_path`.
    assert caption.shape[0] == image_path.shape[0]

    dataset = tf.data.Dataset.from_tensor_slices((image_path, caption))
    # Shuffle the lightweight (path, ids) pairs *before* decoding. The shuffle
    # buffer then holds strings and ints rather than every decoded image.
    dataset = dataset.shuffle(len(caption))
    dataset = dataset.map(data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset


In [8]:
def testing_data_generator(caption, index):
    """``tf.data`` map function for the test captions.

    Args:
        caption: token-id sequence for one test caption.
        index: the sample's ID. It is accepted because the dataset yields
            (caption, ID) pairs, but it is not used.

    Returns:
        (caption, caption_embedding): the ids cast to int32, and their
        embeddings averaged over the token axis.
    """
    # Token ids are integers. Cast exactly like training_data_generator so the
    # embedding lookup sees the same dtype at train and test time.
    caption = tf.cast(caption, tf.int32)
    caption_embedding = caption_embedding_model(caption)
    caption_embedding = tf.reduce_mean(caption_embedding, axis=0)
    return caption, caption_embedding

def testing_dataset_generator(batch_size, data_generator, filename='./dataset/testData.pkl'):
    """Build the endlessly repeating, batched test-caption dataset.

    Args:
        batch_size: number of captions per batch.
        data_generator: map function applied to each (caption_ids, ID) pair,
            e.g. testing_data_generator.
        filename: pickled DataFrame with a 'Captions' column (one id sequence
            per row) and an 'ID' column. Defaults to the original path.

    Returns:
        An unshuffled dataset that repeats forever, so callers must limit how
        many batches they take.
    """
    # NOTE: read_pickle unpickles arbitrary objects; only load trusted files.
    data = pd.read_pickle(filename)
    # np.int was an alias of the builtin int and was removed in NumPy 1.24.
    caption = np.asarray(list(data['Captions'].values)).astype(int)
    index = np.asarray(data['ID'].values)

    dataset = tf.data.Dataset.from_tensor_slices((caption, index))
    dataset = dataset.map(data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat().batch(batch_size)

    return dataset

# Stage-I

In [9]:
def generator_c(x):
    """Sample a conditioning vector c = mean + exp(log_sigma) * epsilon.

    ``x`` is split column-wise: the first 128 columns are the mean and the
    rest are the log standard deviation. Epsilon has one value per feature
    and is shared across the batch. This matches how Generator_stage1 uses
    its ``condition_var`` input.
    """
    mean = x[:, :128]
    log_sigma = x[:, 128:]
    stddev = tf.exp(log_sigma)
    # tf.random.normal only supports floating dtypes; the original requested
    # tf.int32, which raises. Draw the noise in mean's dtype instead.
    epsilon = tf.random.normal(tf.shape(mean)[1:], dtype=mean.dtype)
    c = stddev * epsilon + mean
    return c

class CA(keras.Model):
    """Conditioning-augmentation head: Dense(256) followed by LeakyReLU(0.2).

    Intended for a (1024,) text embedding and returns a 256-wide tensor.
    It is not instantiated anywhere in the visible notebook; the generators
    carry their own copy of this head.
    """

    def __init__(self):
        super().__init__()
        self.fc = layers.Dense(256)
        self.activation = layers.LeakyReLU(alpha=0.2)

    def call(self, inputs, training=False):
        projected = self.fc(inputs)
        return self.activation(projected)

class Embedding_Compressor(keras.Model):
    """Compress a text embedding to 128 features with Dense(128) followed by ReLU.

    It is not instantiated anywhere in the visible notebook.
    """

    def __init__(self):
        super().__init__()
        self.fc = layers.Dense(128)
        self.activation = layers.ReLU()

    def call(self, inputs, training=False):
        compressed = self.fc(inputs)
        return self.activation(compressed)

class Generator_stage1(keras.Model):
    """Stage-I generator: (text embedding, noise z, condition noise) -> 64x64 RGB image.

    ``inputs`` is the list ``[embedding_batch, z_noise, condition_var]``, in
    the order the loss functions pass it. ``call`` returns
    ``(image, [mean, log_sigma])``: the image comes out of tanh, and the
    pair feeds KL_loss.
    """

    def __init__(self):
        super(Generator_stage1, self).__init__()
        # Layer creation order is kept unchanged so saved .h5 weights still load.
        # Conditioning augmentation: Dense(256) whose output is split into a
        # 128-d mean and a 128-d log-sigma.
        self.ca_fc = layers.Dense(256)
        self.ca_activation = layers.LeakyReLU(alpha=0.2)
        # Projects [c, z] to a 4x4x1024 seed feature map.
        self.fc1 = layers.Dense(128 * 8 * 4 * 4, use_bias=False)
        self.activation = layers.ReLU()

        # Four upsample(x2) -> conv3x3 -> BN -> ReLU stages: 4 -> 8 -> 16 -> 32 -> 64.
        self.upsampling1 = layers.UpSampling2D(size=(2, 2))
        self.conv1 = layers.Conv2D(512, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.ac1 = layers.ReLU()

        self.upsampling2 = layers.UpSampling2D(size=(2, 2))
        self.conv2 = layers.Conv2D(256, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.ac2 = layers.ReLU()

        self.upsampling3 = layers.UpSampling2D(size=(2, 2))
        self.conv3 = layers.Conv2D(128, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.ac3 = layers.ReLU()

        self.upsampling4 = layers.UpSampling2D(size=(2, 2))
        self.conv4 = layers.Conv2D(64, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        self.ac4 = layers.ReLU()

        # Final projection to 3 RGB channels (tanh applied in call).
        self.conv5 = layers.Conv2D(3, kernel_size=3, strides=1, padding='same', use_bias=False)

    def call(self, inputs, training=False):
        embedding, z_noise, condition_var = inputs[0], inputs[1], inputs[2]

        # Conditioning augmentation: c = exp(log_sigma) * eps + mean.
        ca_out = self.ca_activation(self.ca_fc(embedding))
        mean_logsigma = tf.split(ca_out, num_or_size_splits=2, axis=-1)
        mean, log_sigma = mean_logsigma[0], mean_logsigma[1]
        c = tf.exp(log_sigma) * condition_var + mean

        features = self.activation(self.fc1(tf.concat([c, z_noise], axis=1)))
        features = tf.reshape(features, shape=(-1, 4, 4, 128 * 8))

        upsampling_stages = (
            (self.upsampling1, self.conv1, self.bn1, self.ac1),
            (self.upsampling2, self.conv2, self.bn2, self.ac2),
            (self.upsampling3, self.conv3, self.bn3, self.ac3),
            (self.upsampling4, self.conv4, self.bn4, self.ac4),
        )
        for upsample, conv, norm, act in upsampling_stages:
            features = act(norm(conv(upsample(features)), training=training))

        image = tf.tanh(self.conv5(features))
        return image, mean_logsigma

class Discriminator_stage1(keras.Model):
    """Stage-I discriminator: scores (64x64 image, text embedding) pairs.

    ``inputs`` is ``[images, embedding_batch]``. Four stride-2 convolutions
    reduce a 64x64 image to a 4x4 feature map. The text embedding is
    compressed to 128 features, tiled over that 4x4 grid and concatenated
    with it, then reduced to one sigmoid probability per sample.
    """

    def __init__(self):
        super(Discriminator_stage1, self).__init__()
        # Layer creation order is kept unchanged so saved .h5 weights still load.
        self.e_fc = layers.Dense(128)
        self.e_ac = layers.LeakyReLU(alpha=0.2)

        self.conv1 = layers.Conv2D(64, kernel_size=(4, 4), padding='same', strides=2, use_bias=False)
        self.ac1 = layers.LeakyReLU(alpha=0.2)

        self.conv2 = layers.Conv2D(128, kernel_size=(4, 4), padding='same', strides=2, use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.ac2 = layers.LeakyReLU(alpha=0.2)

        self.conv3 = layers.Conv2D(256, kernel_size=(4, 4), padding='same', strides=2, use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.ac3 = layers.LeakyReLU(alpha=0.2)

        self.conv4 = layers.Conv2D(512, kernel_size=(4, 4), padding='same', strides=2, use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.ac4 = layers.LeakyReLU(alpha=0.2)

        self.conv5 = layers.Conv2D(512, kernel_size=1, padding='same', strides=1)
        self.bn4 = layers.BatchNormalization()
        self.ac5 = layers.LeakyReLU(alpha=0.2)

        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=False):
        images, embedding = inputs[0], inputs[1]

        # The first conv has no batch norm; the next three use conv -> BN -> LeakyReLU.
        features = self.ac1(self.conv1(images))
        downsampling_stages = (
            (self.conv2, self.bn1, self.ac2),
            (self.conv3, self.bn2, self.ac3),
            (self.conv4, self.bn3, self.ac4),
        )
        for conv, norm, act in downsampling_stages:
            features = act(norm(conv(features), training=training))

        # Compress the text embedding and replicate it over the 4x4 spatial grid.
        text_code = self.e_ac(self.e_fc(embedding))
        text_code = tf.tile(tf.reshape(text_code, shape=(-1, 1, 1, 128)), [1, 4, 4, 1])

        joint = tf.concat([features, text_code], axis=-1)
        joint = self.ac5(self.bn4(self.conv5(joint), training=training))
        return tf.sigmoid(self.fc(self.flatten(joint)))

# Stage-II

In [10]:
class Residual_block(layers.Layer):
    """Two 3x3 conv + batch-norm layers with an identity shortcut, then ReLU.

    The shortcut adds ``inputs`` directly, so the input must have exactly
    ``filters`` channels.
    """

    def __init__(self, filters=128 * 4):
        super(Residual_block, self).__init__()
        # Conv2D's keyword is `strides` (plural). The original `stride=1` is
        # rejected by Keras with a TypeError when the layer is constructed.
        self.conv1 = layers.Conv2D(filters, kernel_size=(3, 3), padding='same', strides=1)
        self.bn1 = layers.BatchNormalization()
        self.ac1 = layers.ReLU()
        self.conv2 = layers.Conv2D(filters, kernel_size=(3, 3), padding='same', strides=1)
        self.bn2 = layers.BatchNormalization()
        self.ac2 = layers.ReLU()

    def call(self, inputs, training=False):
        x = self.bn1(self.conv1(inputs), training=training)
        x = self.ac1(x)
        x = self.bn2(self.conv2(x), training=training)
        # Identity shortcut; plain tensor addition avoids building a new Add
        # layer on every call.
        x = x + inputs
        return self.ac2(x)

class Generator_stage2(keras.Model):
    """Stage-II generator: refines a 64x64 Stage-I image into a 256x256 image.

    ``inputs`` is ``[embedding_batch, lr_images, condition_var]``.
    ``call`` returns ``(image, [mean, log_sigma])``: the image comes out of
    tanh, and the pair feeds KL_loss.
    """

    def __init__(self):
        super(Generator_stage2, self).__init__()
        # Layer creation order is kept unchanged so saved .h5 weights still load.
        # Conditioning augmentation (Dense(256) split into mean / log-sigma).
        self.ca_fc = layers.Dense(256)
        self.ca_activation = layers.LeakyReLU(alpha=0.2)

        # Image encoder: 64x64 -> 32x32 -> 16x16 feature maps.
        self.conv1 = layers.Conv2D(128, kernel_size=(3, 3), strides=1, padding='same', use_bias=False)
        self.ac1 = layers.ReLU()
        self.conv2 = layers.Conv2D(256, kernel_size=(4, 4), strides=2, padding='same', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.ac2 = layers.ReLU()
        self.conv3 = layers.Conv2D(512, kernel_size=(4, 4), strides=2, padding='same', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.ac3 = layers.ReLU()

        # Joint (image, text) encoding before the residual stack.
        self.conv4 = layers.Conv2D(512, kernel_size=(3, 3), strides=1, padding='same', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.ac4 = layers.ReLU()

        self.rb1 = Residual_block()
        self.rb2 = Residual_block()
        self.rb3 = Residual_block()
        self.rb4 = Residual_block()

        # Upsampling: 16 -> 32 -> 64 -> 128 -> 256.
        self.upsampling1 = layers.UpSampling2D(size=(2, 2))
        self.conv5 = layers.Conv2D(512, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        self.ac5 = layers.ReLU()

        self.upsampling2 = layers.UpSampling2D(size=(2, 2))
        self.conv6 = layers.Conv2D(256, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.bn5 = layers.BatchNormalization()
        self.ac6 = layers.ReLU()

        self.upsampling3 = layers.UpSampling2D(size=(2, 2))
        self.conv7 = layers.Conv2D(128, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.bn6 = layers.BatchNormalization()
        self.ac7 = layers.ReLU()

        self.upsampling4 = layers.UpSampling2D(size=(2, 2))
        self.conv8 = layers.Conv2D(64, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.bn7 = layers.BatchNormalization()
        self.ac8 = layers.ReLU()

        self.conv9 = layers.Conv2D(3, kernel_size=3, strides=1, padding='same', use_bias=False)

    def call(self, inputs, training=False):
        # Conditioning augmentation: c = exp(log_sigma) * eps + mean.
        mean_logsigma = tf.split(self.ca_activation(self.ca_fc(inputs[0])), num_or_size_splits=2, axis=-1)
        mean = mean_logsigma[0]
        log_sigma = mean_logsigma[1]
        stddev = tf.exp(log_sigma)
        c = stddev * inputs[2] + mean

        # Image encoder.
        x = self.ac1(self.conv1(inputs[1]))
        x = self.ac2(self.bn1(self.conv2(x), training=training))
        x = self.ac3(self.bn2(self.conv3(x), training=training))

        # Broadcast c over the encoder's 16x16 spatial grid and concatenate.
        c = tf.expand_dims(c, axis=1)
        c = tf.expand_dims(c, axis=1)
        c = tf.tile(c, [1, 16, 16, 1])
        c_code = tf.concat([c, x], axis=3)

        # Residual stack. `training` is passed explicitly so the blocks'
        # batch-norm layers follow the same mode as every other BN layer here.
        x = self.ac4(self.bn3(self.conv4(c_code), training=training))
        x = self.rb1(x, training=training)
        x = self.rb2(x, training=training)
        x = self.rb3(x, training=training)
        x = self.rb4(x, training=training)

        # Upsampling block.
        x = self.ac5(self.bn4(self.conv5(self.upsampling1(x)), training=training))
        x = self.ac6(self.bn5(self.conv6(self.upsampling2(x)), training=training))
        x = self.ac7(self.bn6(self.conv7(self.upsampling3(x)), training=training))
        x = self.ac8(self.bn7(self.conv8(self.upsampling4(x)), training=training))
        x = self.conv9(x)
        x = tf.tanh(x)
        return x, mean_logsigma

class Discriminator_stage2(keras.Model):
    """Stage-II discriminator: scores (256x256 image, text embedding) pairs.

    ``inputs`` is ``[images, embedding_batch]``. Six stride-2 convolutions
    reduce a 256x256 image to a 4x4 map. A 1x1/3x3 bottleneck with a
    residual connection refines it. The map is then concatenated with the
    compressed text embedding tiled to 4x4 and reduced to one sigmoid
    probability per sample.
    """

    def __init__(self):
        super(Discriminator_stage2, self).__init__()
        # Layer creation order is kept unchanged so saved .h5 weights still load.
        self.e_fc = layers.Dense(128)
        self.e_ac = layers.LeakyReLU(alpha=0.2)

        self.conv1 = layers.Conv2D(64, kernel_size=(4, 4), strides=2, padding='same', use_bias=False)
        self.ac1 = layers.LeakyReLU(alpha=0.2)

        self.conv2 = layers.Conv2D(128, kernel_size=(4, 4), strides=2, padding='same', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.ac2 = layers.LeakyReLU(alpha=0.2)

        self.conv3 = layers.Conv2D(256, kernel_size=(4, 4), strides=2, padding='same', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.ac3 = layers.LeakyReLU(alpha=0.2)

        self.conv4 = layers.Conv2D(512, kernel_size=(4, 4), strides=2, padding='same', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.ac4 = layers.LeakyReLU(alpha=0.2)

        self.conv5 = layers.Conv2D(1024, kernel_size=(4, 4), strides=2, padding='same', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        self.ac5 = layers.LeakyReLU(alpha=0.2)

        self.conv6 = layers.Conv2D(2048, kernel_size=(4, 4), strides=2, padding='same', use_bias=False)
        self.bn5 = layers.BatchNormalization()
        self.ac6 = layers.LeakyReLU(alpha=0.2)

        self.conv7 = layers.Conv2D(1024, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)
        self.bn6 = layers.BatchNormalization()
        self.ac7 = layers.LeakyReLU(alpha=0.2)

        self.conv8 = layers.Conv2D(512, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)
        self.bn7 = layers.BatchNormalization()

        # Residual branch.
        self.conv9 = layers.Conv2D(128, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)
        self.bn8 = layers.BatchNormalization()
        self.ac8 = layers.LeakyReLU(alpha=0.2)

        self.conv10 = layers.Conv2D(128, kernel_size=(3, 3), strides=1, padding='same', use_bias=False)
        self.bn9 = layers.BatchNormalization()
        self.ac9 = layers.LeakyReLU(alpha=0.2)

        self.conv11 = layers.Conv2D(512, kernel_size=(3, 3), strides=1, padding='same', use_bias=False)
        self.bn10 = layers.BatchNormalization()

        self.ac10 = layers.LeakyReLU(alpha=0.2)

        self.conv12 = layers.Conv2D(64 * 8, kernel_size=1, strides=1, padding='same')
        self.bn11 = layers.BatchNormalization()
        self.ac11 = layers.LeakyReLU(alpha=0.2)

        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=False):
        x = self.ac1(self.conv1(inputs[0]))
        x = self.ac2(self.bn1(self.conv2(x), training=training))
        x = self.ac3(self.bn2(self.conv3(x), training=training))
        x = self.ac4(self.bn3(self.conv4(x), training=training))
        x = self.ac5(self.bn4(self.conv5(x), training=training))
        x = self.ac6(self.bn5(self.conv6(x), training=training))
        x = self.ac7(self.bn6(self.conv7(x), training=training))
        # bn7/bn10 now get `training` explicitly, like every other BN layer here.
        x = self.bn7(self.conv8(x), training=training)

        x2 = self.ac8(self.bn8(self.conv9(x), training=training))
        x2 = self.ac9(self.bn9(self.conv10(x2), training=training))
        x2 = self.bn10(self.conv11(x2), training=training)

        added_x = self.ac10(x + x2)

        # Compress the text embedding and replicate it over the 4x4 grid.
        input_layer2 = self.e_ac(self.e_fc(inputs[1]))
        input_layer2 = tf.reshape(input_layer2, shape=(-1, 1, 1, 128))
        input_layer2 = tf.tile(input_layer2, [1, 4, 4, 1])
        x3 = tf.concat([added_x, input_layer2], axis=-1)

        x3 = self.ac11(self.bn11(self.conv12(x3), training=training))
        # Was `self.faltten`, which raised AttributeError on every call.
        x3 = self.flatten(x3)
        x3 = self.fc(x3)
        x3 = tf.sigmoid(x3)
        return x3

# Loss functions

In [11]:
def celoss_zeros(logits):
    """Mean binary cross-entropy against the label-smoothed "fake" target 0.1.

    Despite its name, ``logits`` holds sigmoid probabilities (both
    discriminators end in tf.sigmoid). This matches binary_crossentropy's
    default ``from_logits=False``.
    """
    smoothed_targets = 0.1 * tf.ones_like(logits)
    per_sample = keras.losses.binary_crossentropy(smoothed_targets, logits)
    return tf.reduce_mean(per_sample)

def celoss_ones(logits):
    """Mean binary cross-entropy against the label-smoothed "real" target 0.9.

    Despite its name, ``logits`` holds sigmoid probabilities (both
    discriminators end in tf.sigmoid). This matches binary_crossentropy's
    default ``from_logits=False``.
    """
    smoothed_targets = 0.9 * tf.ones_like(logits)
    per_sample = keras.losses.binary_crossentropy(smoothed_targets, logits)
    return tf.reduce_mean(per_sample)

def KL_loss(logits):
    """KL(N(mean, sigma^2) || N(0, 1)), averaged over every element.

    ``logits`` is the ``[mean, log_sigma]`` pair returned by the generators.
    Per element: -log_sigma + 0.5 * (sigma^2 + mean^2 - 1).
    """
    mean, logsigma = logits[0], logits[1]
    per_element = -logsigma + 0.5 * (-1 + tf.exp(2. * logsigma) + tf.square(mean))
    return tf.reduce_mean(per_element)

def d_loss_fn(batch_size, generator, discriminator, img_batch, embedding_batch, z_noise, condition_var, training):
    """Stage-I discriminator loss: fake + real + mismatched-text terms.

    Args:
        batch_size: kept for backward compatibility. The mismatched pairs are
            now built from the actual batch length, so a value that disagrees
            with ``img_batch`` can no longer produce mismatched slice shapes.
        generator, discriminator: Stage-I models.
        img_batch, embedding_batch: real images and their caption embeddings.
        z_noise, condition_var: generator noise inputs.
        training: forwarded to both models.

    Returns:
        Scalar loss: sum of the three smoothed-BCE terms.
    """
    # Generate fake images.
    fake_images, _ = generator([embedding_batch, z_noise, condition_var], training)
    # Score generated images against the "fake" label.
    d_fake_logits = discriminator([fake_images, embedding_batch], training)
    d_loss_fake = celoss_zeros(d_fake_logits)
    # Score real images with their own captions against the "real" label.
    d_real_logits = discriminator([img_batch, embedding_batch], training)
    d_loss_real = celoss_ones(d_real_logits)
    # Score real images paired with the next sample's caption as "fake".
    d_wrong_logits = discriminator([img_batch[:-1], embedding_batch[1:]], training)
    d_loss_wrong = celoss_zeros(d_wrong_logits)
    return d_loss_fake + d_loss_real + d_loss_wrong

def g_loss_fn(generator, discriminator, embedding_batch, z_noise, condition_var, training):
    """Stage-I generator loss: fool the discriminator, plus 2x the CA KL regulariser."""
    fake_images, mean_logsigma = generator([embedding_batch, z_noise, condition_var], training)
    adversarial_loss = celoss_ones(discriminator([fake_images, embedding_batch], training))
    kl_loss = KL_loss(mean_logsigma)
    return adversarial_loss + 2.0 * kl_loss

def d_loss_fn_stage2(batch_size=64,
                     gen_stage1=None,
                     gen_stage2=None,
                     dis_stage2=None,
                     image_batch=None,
                     embedding_batch=None,
                     z_noise=None,
                     condition_var=None,
                     training=False):
    """Stage-II discriminator loss: fake + real + mismatched-text terms.

    The Stage-I generator is called without ``training``, so its default
    (False) applies. ``batch_size`` is kept for backward compatibility: the
    mismatched pairs are built from the actual batch length instead.

    Returns:
        Scalar loss: sum of the three smoothed-BCE terms.
    """
    lr_fake_images, _ = gen_stage1([embedding_batch, z_noise, condition_var])
    hr_fake_images, _ = gen_stage2([embedding_batch, lr_fake_images, condition_var], training)
    # Score generated high-resolution images against the "fake" label.
    d_fake_logits = dis_stage2([hr_fake_images, embedding_batch], training)
    d_loss_fake = celoss_zeros(d_fake_logits)
    # Score real images with their own captions against the "real" label.
    d_real_logits = dis_stage2([image_batch, embedding_batch], training)
    d_loss_real = celoss_ones(d_real_logits)
    # Score real images paired with the next sample's caption as "fake".
    d_wrong_logits = dis_stage2([image_batch[:-1], embedding_batch[1:]], training)
    d_loss_wrong = celoss_zeros(d_wrong_logits)
    return d_loss_fake + d_loss_real + d_loss_wrong

def g_loss_fn_stage2(gen_stage1=None,
                     gen_stage2=None,
                     dis_stage2=None,
                     embedding_batch=None,
                     z_noise=None,
                     condition_var=None,
                     training=False):
    """Stage-II generator loss: fool the discriminator, plus 2x the CA KL regulariser.

    The Stage-I generator only supplies low-resolution inputs. It is called
    without ``training``, so its default (False) applies.
    """
    low_res, _ = gen_stage1([embedding_batch, z_noise, condition_var])
    high_res, mean_logsigma = gen_stage2([embedding_batch, low_res, condition_var], training)
    adversarial_loss = celoss_ones(dis_stage2([high_res, embedding_batch], training))
    kl_loss = KL_loss(mean_logsigma)
    return adversarial_loss + 2.0 * kl_loss

# Image-saving helper

In [12]:
def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of generated images into a grid and save it to disk.

    Args:
        val_out: array of shape (N, H, W, C) with values in [-1, 1] (the
            generators' tanh range).
        val_block_size: images per grid row. Only the first
            ``(N // val_block_size) * val_block_size`` images are used.
        image_path: destination file path.
        color_mode: unused; kept so existing call sites keep working.

    Raises:
        ValueError: if ``N < val_block_size``, because no complete row can be
            formed. The original failed here with an obscure IndexError.
    """
    n_images, height, width, channels = val_out.shape
    n_rows = n_images // val_block_size
    if n_rows == 0:
        raise ValueError(
            f"need at least {val_block_size} images to fill one row, got {n_images}")

    # Map [-1, 1] -> [0, 255]. Clipping makes slightly out-of-range values
    # saturate instead of wrapping around when cast to uint8.
    pixels = np.clip((val_out + 1.0) * (255. / 2), 0, 255).astype(np.uint8)

    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C):
    # images sit side by side within a row, and rows stack vertically.
    grid = (pixels[:n_rows * val_block_size]
            .reshape(n_rows, val_block_size, height, width, channels)
            .transpose(0, 2, 1, 3, 4)
            .reshape(n_rows * height, val_block_size * width, channels))

    if channels == 1:
        grid = np.squeeze(grid, axis=2)
    Image.fromarray(grid).save(image_path)

# Model training

In [13]:
import time

# Fixed latent noise `z` and test-caption embeddings, reused by both training
# loops when they render sample grids.
z = tf.random.normal([64, 100])
db_test = iter(testing_dataset_generator(8, testing_data_generator))
caption, embeddings_test = next(db_test)
# 8 captions x 8 repeats -> 64 embeddings, one per row of `z`.
embeddings_test = np.repeat(embeddings_test, 8, axis=0)

# Decode the 8 test captions back to words so the sample grids can be read.
caption = caption.numpy()
print('==================================================')
for i, token_ids in enumerate(caption):
    s = "".join(id2word_dict[str(int(token_id))] + " " for token_id in token_ids)
    print(f'index: {i}   caption: {s}')
print('==================================================')

def main_stage1():
    """Train the Stage-I (64x64) text-to-image GAN.

    Tries to resume from ``stage1_gen.h5`` / ``stage1_dis.h5`` in the working
    directory; load errors are printed and training continues from the
    current weights.
    - Every 2 epochs: prints the last batch's losses and writes an 8x8 sample
      grid, generated from the global ``embeddings_test`` and ``z``, to
      ``testout/``.
    - Every 5 epochs: writes timestamped weights to ``./weight/``. To resume
      from them, copy one to the file names above.
    """
    batch_size = 64
    z_dim = 100
    generator_lr = 0.0002
    discriminator_lr = 0.0002
    epochs = 10000
    condition_dim = 128
    training = True

    # Both output folders are written below; create them so the first save
    # cannot fail on a missing directory.
    os.makedirs('testout', exist_ok=True)
    os.makedirs('./weight', exist_ok=True)

    # `learning_rate` is the documented keyword; `lr` is a deprecated alias
    # that newer Keras releases warn about or reject.
    d_optimizer = keras.optimizers.Adam(learning_rate=discriminator_lr, beta_1=0.5, beta_2=0.999)
    g_optimizer = keras.optimizers.Adam(learning_rate=generator_lr, beta_1=0.5, beta_2=0.999)

    db_train = dataset_generator('./dataset/text2ImgData.pkl', batch_size, training_data_generator)

    gen = Generator_stage1()
    gen.build([[4, 1024], [4, 100], [128]])  # [embedding_batch, z_noise, condition_var]
    try:
        gen.load_weights("stage1_gen.h5")
    except Exception as e:  # best-effort resume
        print(e)

    dis = Discriminator_stage1()
    dis.build([[4, 64, 64, 3], [4, 1024]])
    try:
        dis.load_weights("stage1_dis.h5")
    except Exception as e:  # best-effort resume
        print(e)

    for epoch in range(epochs):
        for x, embedding in db_train:
            # Discriminator step. One condition-noise vector is shared by the
            # whole batch.
            z_noise = tf.random.normal(shape=(batch_size, z_dim))
            condition_var = tf.random.normal(shape=(condition_dim,))
            with tf.GradientTape() as tape:
                d_loss = d_loss_fn(batch_size, gen, dis, x, embedding, z_noise, condition_var, training)
            grads = tape.gradient(d_loss, dis.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, dis.trainable_variables))

            # Generator step with fresh noise.
            z_noise = tf.random.normal(shape=(batch_size, z_dim))
            condition_var = tf.random.normal(shape=(condition_dim,))
            with tf.GradientTape() as tape:
                g_loss = g_loss_fn(gen, dis, embedding, z_noise, condition_var, training)
            grads = tape.gradient(g_loss, gen.trainable_variables)
            g_optimizer.apply_gradients(zip(grads, gen.trainable_variables))

        if epoch % 2 == 0:
            print(epoch, 'd_loss:', float(d_loss), 'g_loss:', float(g_loss))
            # Visualize progress on the fixed test captions.
            condition_var = tf.random.normal(shape=(condition_dim,))
            fake_image, _ = gen([embeddings_test, z, condition_var], training=False)
            img_path = r'testout/ganstage1-{}.png'.format(epoch)
            save_result(fake_image.numpy(), 8, img_path, color_mode='P')
        if epoch % 5 == 0:
            timestamp = int(time.time())
            gen.save_weights(f"./weight/stage1_gen_{epoch}_{timestamp}.h5")
            dis.save_weights(f"./weight/stage1_dis_{epoch}_{timestamp}.h5")

def _hr_training_data_generator(image_path, caption, image_size=256):
    """``tf.data`` map function producing (image_size x image_size image, caption embedding).

    Stage-II needs its own loader because training_data_generator always
    resizes to IMAGE_HEIGHT x IMAGE_WIDTH (64x64). Pixels are rescaled to
    [-1, 1] to match the generators' tanh output. The caption ids are
    embedded with ``caption_embedding_model`` and averaged over the token
    axis, as in the Stage-I loader.
    """
    img = tf.io.read_file(image_path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)  # -> [0, 1]
    img.set_shape([None, None, 3])
    img = tf.image.resize(img, size=[image_size, image_size])
    img.set_shape([image_size, image_size, IMAGE_CHANNEL])
    img = img * 2.0 - 1.0

    caption = tf.cast(caption, tf.int32)
    caption_embedding = tf.reduce_mean(caption_embedding_model(caption), axis=0)
    return img, caption_embedding


def main_stage2():
    """Train the Stage-II (64x64 -> 256x256) refinement GAN.

    The Stage-I generator is loaded from ``stage1_gen.h5``; only the Stage-II
    generator and discriminator receive gradient updates. Stage-II tries to
    resume from ``stage2_gen.h5`` / ``stage2_dis.h5``; load errors are
    printed and ignored.
    - Every 10 epochs: prints the losses and writes an 8x8 sample grid to
      ``testout/``.
    - Every 5 epochs: writes timestamped weights to ``./weight/``.
    """
    image_size = 256
    batch_size = 64
    z_dim = 100
    generator_lr = 0.0002
    discriminator_lr = 0.0002
    epochs = 10000
    condition_dim = 128
    training = True

    # Both output folders are written below; create them up front.
    os.makedirs('testout', exist_ok=True)
    os.makedirs('./weight', exist_ok=True)

    # `learning_rate` is the documented keyword; `lr` is a deprecated alias.
    d_optimizer = keras.optimizers.Adam(learning_rate=discriminator_lr, beta_1=0.5, beta_2=0.999)
    g_optimizer = keras.optimizers.Adam(learning_rate=generator_lr, beta_1=0.5, beta_2=0.999)

    # Real images must be 256x256 to match gen_stage2's output and the input
    # size dis_stage2 is built for. The 64x64 images from
    # training_data_generator break dis_stage2's 4x4 concat and Dense layer.
    db_hr_train = dataset_generator(
        './dataset/text2ImgData.pkl', batch_size,
        lambda path, ids: _hr_training_data_generator(path, ids, image_size))

    gen_stage1 = Generator_stage1()
    gen_stage1.build([[4, 1024], [4, 100], [128]])
    try:
        gen_stage1.load_weights("stage1_gen.h5")
    except Exception as e:  # best-effort; Stage-II is meaningless with an untrained Stage-I
        print(e)

    gen_stage2 = Generator_stage2()
    gen_stage2.build([[4, 1024], [4, 64, 64, 3], [128]])
    try:
        gen_stage2.load_weights("stage2_gen.h5")
    except Exception as e:  # best-effort resume
        print(e)

    dis_stage2 = Discriminator_stage2()
    dis_stage2.build([[4, image_size, image_size, 3], [4, 1024]])
    try:
        dis_stage2.load_weights("stage2_dis.h5")
    except Exception as e:  # best-effort resume
        print(e)

    for epoch in range(epochs):
        for x, embedding in db_hr_train:
            # Discriminator step.
            z_noise = tf.random.normal(shape=(batch_size, z_dim))
            condition_var = tf.random.normal(shape=(condition_dim,))
            with tf.GradientTape() as tape:
                d_loss = d_loss_fn_stage2(batch_size=batch_size,
                                          gen_stage1=gen_stage1,
                                          gen_stage2=gen_stage2,
                                          dis_stage2=dis_stage2,
                                          image_batch=x,
                                          embedding_batch=embedding,
                                          z_noise=z_noise,
                                          condition_var=condition_var,
                                          training=training)
            grads = tape.gradient(d_loss, dis_stage2.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, dis_stage2.trainable_variables))

            # Generator step (Stage-II generator only) with fresh noise.
            z_noise = tf.random.normal(shape=(batch_size, z_dim))
            condition_var = tf.random.normal(shape=(condition_dim,))
            with tf.GradientTape() as tape:
                g_loss = g_loss_fn_stage2(gen_stage1=gen_stage1,
                                          gen_stage2=gen_stage2,
                                          dis_stage2=dis_stage2,
                                          embedding_batch=embedding,
                                          z_noise=z_noise,
                                          condition_var=condition_var,
                                          training=training)
            grads = tape.gradient(g_loss, gen_stage2.trainable_variables)
            g_optimizer.apply_gradients(zip(grads, gen_stage2.trainable_variables))

        if epoch % 10 == 0:
            print(epoch, 'd_loss:', float(d_loss), 'g_loss:', float(g_loss))
            # Visualize progress on the fixed test captions.
            condition_var = tf.random.normal(shape=(condition_dim,))
            lr_fake_image, _ = gen_stage1([embeddings_test, z, condition_var], training=False)
            hr_fake_image, _ = gen_stage2([embeddings_test, lr_fake_image, condition_var], training=False)
            img_path = r'testout/ganstage2-{}.png'.format(epoch)
            save_result(hr_fake_image.numpy(), 8, img_path, color_mode='P')
        if epoch % 5 == 0:
            timestamp = int(time.time())
            gen_stage2.save_weights(f"./weight/stage2_gen_{epoch}_{timestamp}.h5")
            dis_stage2.save_weights(f"./weight/stage2_dis_{epoch}_{timestamp}.h5")


In [None]:
# Train the Stage-I (64x64) GAN. Checkpoints are written to ./weight/ with
# epoch/timestamp suffixes.
main_stage1()

In [None]:
# Train the Stage-II (256x256) GAN. It loads Stage-I generator weights from
# ./stage1_gen.h5, so copy/rename a Stage-I checkpoint from ./weight/ first.
main_stage2()