# 改进的文本到图像生成模型 (Text-to-Image GAN)

## 主要改进点

本notebook基于比赛PDF中的hints进行了以下改进：

### 1. 数据增强 (Data Augmentation)
- 随机裁剪和翻转
- 亮度和对比度调整
- 颜色抖动
- 提高模型泛化能力，防止过拟合

### 2. 改进的模型架构
- **Text Encoder**: 使用双向GRU + 多层RNN结构，更好地捕捉文本语义
- **Generator**: 采用DCGAN架构，使用转置卷积逐步生成64x64图像
- **Discriminator**: CNN架构 + Dropout + Batch Normalization，提高判别能力

### 3. 更复杂的损失函数
- 标签平滑 (Label Smoothing) - 防止判别器过于自信
- WGAN-GP选项 - 更稳定的训练（可选）
- 梯度惩罚 (Gradient Penalty) - 改善训练稳定性

### 4. 训练技巧
- 梯度裁剪 - 防止梯度爆炸
- 学习率调度 - 动态调整学习率
- 最佳模型保存 - 保存表现最好的模型
- 训练历史可视化 - 监控训练过程

### 5. 可选的高级技术
- 预训练词嵌入 (Word2Vec/GloVe)
- 多样性检查 - 防止模式崩溃
- 更强的数据增强

## 使用说明

1. 按顺序运行所有cell
2. 训练过程中会定期保存checkpoint和生成样本图像
3. 可以调整 `hparas` 中的超参数进行实验
4. 如果训练不稳定，可以尝试启用WGAN-GP (`USE_WGAN_GP = True`)

## 参考论文
- Generative Adversarial Text to Image Synthesis
- DCGAN: Unsupervised Representation Learning with Deep Convolutional GANs
- Improved Training of Wasserstein GANs

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras import layers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path

import re
from IPython import display

Preprocess Text

In [8]:
def sent2IdList(line, MAX_SEQ_LENGTH=20):
    """Convert a caption string into a fixed-length list of vocabulary ids.

    Punctuation is replaced by spaces, the text is split on whitespace, and
    the token list is truncated or right-padded with ``'<PAD>'`` so that it
    holds exactly ``MAX_SEQ_LENGTH`` entries. Every token is then looked up
    in the module-level ``word2Id_dict``; tokens that are not in the
    dictionary are mapped to the id stored under ``'<RARE>'``.

    Args:
        line: Raw caption text.
        MAX_SEQ_LENGTH: Number of ids in the returned list.

    Returns:
        A list with ``MAX_SEQ_LENGTH`` values taken from ``word2Id_dict``.
    """
    seq_length = max(MAX_SEQ_LENGTH, 0)

    # string.punctuation already contains '-' and '.', so a single
    # substitution removes every punctuation character.
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())

    # split() without arguments collapses any run of whitespace and never
    # produces empty tokens, so no extra filtering is required.
    tokens = prep_line.split()

    # Captions longer than the limit previously came back longer than
    # MAX_SEQ_LENGTH, which breaks stacking captions into one array.
    tokens = tokens[:seq_length]
    tokens.extend(['<PAD>'] * (seq_length - len(tokens)))

    # The '<RARE>' entry is only read when a token is actually unknown.
    return [
        word2Id_dict[token] if token in word2Id_dict else word2Id_dict['<RARE>']
        for token in tokens
    ]

# Smoke test: print one caption followed by the id list sent2IdList builds for it.
text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
print(sent2IdList(text))

the flower shown has yellow anther red pistil and bright red petals.
[np.str_('9'), np.str_('1'), np.str_('82'), np.str_('5'), np.str_('11'), np.str_('70'), np.str_('20'), np.str_('31'), np.str_('3'), np.str_('29'), np.str_('20'), np.str_('2'), np.str_('5427'), np.str_('5427'), np.str_('5427'), np.str_('5427'), np.str_('5427'), np.str_('5427'), np.str_('5427'), np.str_('5427')]


In [None]:
# Root folder of the competition dataset (machine-specific absolute path).
data_path = r'C:\Users\11958\Desktop\vscode-c\c\deep_learning\competition\competition3\2025-datalab-cup3-reverse-image-caption\dataset'

# NOTE: read_pickle executes pickle data; only load files from a trusted source.
train_pickle = os.path.join(data_path, 'text2ImgData.pkl')
df = pd.read_pickle(train_pickle)

# Both names hold the number of rows in the training frame.
n_images_train = num_training_sample = len(df)
print('There are %d image in training data' % (n_images_train))

There are 7370 image in training data


Create Dataset by Dataset API

In [12]:
# Preview the first rows of the ImagePath column of the training frame.
print(df['ImagePath'].head())


ID
6734    ./102flowers/image_06734.jpg
6736    ./102flowers/image_06736.jpg
6737    ./102flowers/image_06737.jpg
6738    ./102flowers/image_06738.jpg
6739    ./102flowers/image_06739.jpg
Name: ImagePath, dtype: object


Conditional GAN Model

In [None]:
# Text encoder: embedding -> dropout -> bidirectional GRU -> GRU
class TextEncoder(tf.keras.Model):
    """Encode a batch of caption id sequences into one vector per caption.

    The layers are an embedding, a dropout (rate 0.3), a bidirectional GRU
    with ``RNN_HIDDEN_SIZE`` units per direction, and a second GRU with
    ``RNN_HIDDEN_SIZE * 2`` units that also returns its final state. Both
    recurrent layers use a recurrent dropout of 0.2.

    Args:
        hparas: Hyper-parameter dict. The keys read here are ``BATCH_SIZE``,
            ``VOCAB_SIZE``, ``EMBED_DIM`` and ``RNN_HIDDEN_SIZE``.
    """

    def __init__(self, hparas):
        super(TextEncoder, self).__init__()
        self.hparas = hparas
        self.batch_size = self.hparas['BATCH_SIZE']

        # Token id -> dense vector lookup table.
        self.embedding = layers.Embedding(self.hparas['VOCAB_SIZE'], self.hparas['EMBED_DIM'])
        self.dropout = layers.Dropout(0.3)

        # First recurrent layer reads the caption in both directions.
        self.bi_gru = layers.Bidirectional(
            layers.GRU(self.hparas['RNN_HIDDEN_SIZE'],
                      return_sequences=True,
                      return_state=False,
                      recurrent_initializer='glorot_uniform',
                      recurrent_dropout=0.2)
        )

        # Second recurrent layer refines the sequence and exposes its state.
        self.gru = layers.GRU(self.hparas['RNN_HIDDEN_SIZE'] * 2,
                              return_sequences=True,
                              return_state=True,
                              recurrent_initializer='glorot_uniform',
                              recurrent_dropout=0.2)

    def call(self, text, hidden, training=True):
        """Run the encoder.

        Args:
            text: Batch of token id sequences fed to the embedding layer.
            hidden: Initial state for the second GRU (see
                ``initialize_hidden_state``).
            training: Enables dropout and recurrent dropout when true.

        Returns:
            A pair ``(last_output, state)``: the second GRU's output at the
            final time step and the final state it returned.
        """
        embedded = self.embedding(text)
        embedded = self.dropout(embedded, training=training)

        sequence = self.bi_gru(embedded, training=training)

        # Pass `training` explicitly so recurrent dropout in this layer follows
        # the same flag as every other layer in the encoder. With
        # return_state=True the layer always yields (sequence_output, state).
        output, state = self.gru(sequence, initial_state=hidden, training=training)

        return output[:, -1, :], state

    def initialize_hidden_state(self):
        """Return an all-zero state of shape (BATCH_SIZE, RNN_HIDDEN_SIZE * 2)."""
        return tf.zeros((self.hparas['BATCH_SIZE'], self.hparas['RNN_HIDDEN_SIZE'] * 2))

In [None]:
# Generator built from transposed convolutions (DCGAN-style)
class Generator(tf.keras.Model):
    """Map a text embedding plus a noise vector to an image tensor.

    The text embedding is projected to 256 features, concatenated with the
    noise, expanded by a dense layer into a 4x4x512 feature map and then
    passed through four stride-2 transposed convolutions; the last one has
    3 output channels.
    """

    def __init__(self, hparas):
        super(Generator, self).__init__()
        self.hparas = hparas

        keras_layers = tf.keras.layers

        def upsample(filters):
            # Every transposed convolution shares kernel, stride and padding.
            return keras_layers.Conv2DTranspose(
                filters, (4, 4), strides=(2, 2), padding='same', use_bias=False)

        # Text branch.
        self.text_fc = keras_layers.Dense(256)
        self.text_bn = keras_layers.BatchNormalization()

        # Dense projection of [noise, text] that is reshaped to 4x4x512.
        self.fc1 = keras_layers.Dense(4 * 4 * 512, use_bias=False)
        self.bn1 = keras_layers.BatchNormalization()

        # Upsampling stack: 256 -> 128 -> 64 -> 3 channels.
        self.deconv1 = upsample(256)
        self.bn2 = keras_layers.BatchNormalization()

        self.deconv2 = upsample(128)
        self.bn3 = keras_layers.BatchNormalization()

        self.deconv3 = upsample(64)
        self.bn4 = keras_layers.BatchNormalization()

        self.deconv4 = upsample(3)

    def call(self, text, noise_z, training=True):
        """Generate images.

        Args:
            text: Batch of text embeddings.
            noise_z: Batch of noise vectors, concatenated with the projected
                text along axis 1.
            training: Forwarded to every batch-normalization layer.

        Returns:
            A pair ``(logits, output)`` where ``logits`` is the raw result of
            the last transposed convolution and ``output`` is its ``tanh``.
        """
        condition = tf.nn.relu(self.text_bn(self.text_fc(text), training=training))

        hidden = tf.concat([noise_z, condition], axis=1)
        hidden = tf.nn.relu(self.bn1(self.fc1(hidden), training=training))
        hidden = tf.reshape(hidden, [-1, 4, 4, 512])

        # Three identical deconv -> batch norm -> ReLU stages.
        stages = (
            (self.deconv1, self.bn2),
            (self.deconv2, self.bn3),
            (self.deconv3, self.bn4),
        )
        for deconv, batch_norm in stages:
            hidden = tf.nn.relu(batch_norm(deconv(hidden), training=training))

        logits = self.deconv4(hidden)
        return logits, tf.nn.tanh(logits)

In [None]:
# Discriminator built from strided convolutions (DCGAN-style)
class Discriminator(tf.keras.Model):
    """Score an (image, text embedding) pair.

    The image goes through four stride-2 convolutions (64, 128, 256, 512
    filters) with leaky ReLU and dropout; all but the first are followed by
    batch normalization. The flattened image features are concatenated with
    a 256-feature projection of the text and reduced to a single logit.
    """

    def __init__(self, hparas):
        super(Discriminator, self).__init__()
        self.hparas = hparas

        keras_layers = tf.keras.layers

        def downsample(filters):
            # Every convolution shares kernel, stride and padding.
            return keras_layers.Conv2D(filters, (4, 4), strides=(2, 2), padding='same')

        # Image branch; the first stage has no batch normalization.
        self.conv1 = downsample(64)
        self.dropout1 = keras_layers.Dropout(0.3)

        self.conv2 = downsample(128)
        self.bn2 = keras_layers.BatchNormalization()
        self.dropout2 = keras_layers.Dropout(0.3)

        self.conv3 = downsample(256)
        self.bn3 = keras_layers.BatchNormalization()
        self.dropout3 = keras_layers.Dropout(0.3)

        self.conv4 = downsample(512)
        self.bn4 = keras_layers.BatchNormalization()
        self.dropout4 = keras_layers.Dropout(0.3)

        self.flatten = keras_layers.Flatten()

        # Text branch.
        self.text_fc = keras_layers.Dense(256)
        self.text_bn = keras_layers.BatchNormalization()

        # Fusion head.
        self.fc1 = keras_layers.Dense(512)
        self.fc2 = keras_layers.Dense(1)

    def call(self, img, text, training=True):
        """Score a batch of images against their text embeddings.

        Args:
            img: Batch of images.
            text: Batch of text embeddings.
            training: Forwarded to the dropout and batch-normalization layers.

        Returns:
            A pair ``(logits, output)`` where ``output`` is the sigmoid of
            ``logits``.
        """
        features = self.conv1(img)
        features = tf.nn.leaky_relu(features, alpha=0.2)
        features = self.dropout1(features, training=training)

        # Remaining stages: conv -> batch norm -> leaky ReLU -> dropout.
        stages = (
            (self.conv2, self.bn2, self.dropout2),
            (self.conv3, self.bn3, self.dropout3),
            (self.conv4, self.bn4, self.dropout4),
        )
        for conv, batch_norm, dropout in stages:
            features = batch_norm(conv(features), training=training)
            features = tf.nn.leaky_relu(features, alpha=0.2)
            features = dropout(features, training=training)

        features = self.flatten(features)

        condition = self.text_bn(self.text_fc(text), training=training)
        condition = tf.nn.leaky_relu(condition, alpha=0.2)

        fused = self.fc1(tf.concat([features, condition], axis=1))
        fused = tf.nn.leaky_relu(fused, alpha=0.2)

        logits = self.fc2(fused)
        return logits, tf.nn.sigmoid(logits)

In [18]:
# Build the three networks from the shared hyper-parameter dict `hparas`.
text_encoder = TextEncoder(hparas)
generator = Generator(hparas)
discriminator = Discriminator(hparas)

In [19]:
# Binary cross-entropy loss object; from_logits=True means it must be given
# raw logits rather than sigmoid probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
# Learning-rate schedules: with staircase=True the rate is multiplied by
# LR_DECAY once every 1000 optimizer steps.
lr_schedule_g = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=hparas['LR_G'],
    decay_steps=1000,
    decay_rate=hparas['LR_DECAY'],
    staircase=True
)

lr_schedule_d = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=hparas['LR_D'],
    decay_steps=1000,
    decay_rate=hparas['LR_DECAY'],
    staircase=True
)

# Separate Adam optimizers for generator and discriminator. Each one is given
# its schedule; previously the schedules were created but never attached, so
# the learning rates stayed constant at LR_G / LR_D.
generator_optimizer = tf.keras.optimizers.Adam(
    lr_schedule_g, beta_1=hparas['BETA_1'], beta_2=hparas['BETA_2'])
discriminator_optimizer = tf.keras.optimizers.Adam(
    lr_schedule_d, beta_1=hparas['BETA_1'], beta_2=hparas['BETA_2'])

In [None]:
# Training step: one generator update and one discriminator update, with an
# optional WGAN-GP loss.
@tf.function
def train_step(real_image, caption, hidden, use_wgan_gp=False):
    """Run one optimization step for the generator and the discriminator.

    Args:
        real_image: Batch of real images given to the discriminator.
        caption: Batch of captions given to ``text_encoder``.
        hidden: Initial hidden state given to ``text_encoder``.
        use_wgan_gp: When true, the losses are computed with ``use_wgan=True``
            and ``LAMBDA_GP`` times the gradient penalty is added to the
            discriminator loss.

    Returns:
        A pair ``(g_loss, d_loss)`` with the losses of this step.

    NOTE(review): gradients are taken only with respect to
    ``generator.trainable_variables`` and ``discriminator.trainable_variables``;
    the variables of ``text_encoder`` are in neither list, so this step does
    not update the encoder. Confirm that this is intended.
    """
    # random noise for generator
    # NOTE(review): the noise batch dimension is hparas['BATCH_SIZE'], so this
    # presumably requires every input batch to have that size. Verify the
    # dataset drops incomplete batches.
    noise = tf.random.normal(shape=[hparas['BATCH_SIZE'], hparas['Z_DIM']], mean=0.0, stddev=1.0)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Encode the captions; the returned `hidden` only rebinds the local name.
        text_embed, hidden = text_encoder(caption, hidden, training=True)

        # Generate fake images (second return value is the tanh output).
        _, fake_image = generator(text_embed, noise, training=True)

        # Score real and fake images against the same text embedding.
        real_logits, real_output = discriminator(real_image, text_embed, training=True)
        fake_logits, fake_output = discriminator(fake_image, text_embed, training=True)

        # Compute the losses (generator_loss, discriminator_loss and
        # gradient_penalty are defined outside this block).
        if use_wgan_gp:
            # WGAN-GP losses
            g_loss = generator_loss(fake_logits, use_wgan=True)
            d_loss = discriminator_loss(real_logits, fake_logits, use_wgan=True)

            # Add the weighted gradient penalty to the discriminator loss.
            gp = gradient_penalty(discriminator, real_image, fake_image, text_embed)
            d_loss = d_loss + hparas['LAMBDA_GP'] * gp
        else:
            # Standard GAN losses
            g_loss = generator_loss(fake_logits, use_wgan=False)
            d_loss = discriminator_loss(real_logits, fake_logits, use_wgan=False)

    # Compute gradients, one tape per network.
    grad_g = gen_tape.gradient(g_loss, generator.trainable_variables)
    grad_d = disc_tape.gradient(d_loss, discriminator.trainable_variables)

    # Clip each gradient list to a global norm of 5.0.
    grad_g, _ = tf.clip_by_global_norm(grad_g, 5.0)
    grad_d, _ = tf.clip_by_global_norm(grad_d, 5.0)

    # Apply the updates; both networks are updated on every call.
    generator_optimizer.apply_gradients(zip(grad_g, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(grad_d, discriminator.trainable_variables))

    return g_loss, d_loss

Visualization

In [26]:
def sample_generator(caption, batch_size):
    """Wrap caption id sequences in a batched ``tf.data.Dataset``.

    Args:
        caption: Array-like of caption id sequences; it is converted to a
            NumPy array and cast to ``int``.
        batch_size: Number of captions per batch.

    Returns:
        A dataset that yields the captions in batches of ``batch_size``.
    """
    caption_ids = np.asarray(caption).astype(int)
    return tf.data.Dataset.from_tensor_slices(caption_ids).batch(batch_size)

Training

高级技巧和额外的改进

In [None]:
# 图像质量评估函数
def calculate_inception_score(images, n_split=10, eps=1e-16):
    """Placeholder for an Inception Score metric; not implemented.

    No argument is used and the function always returns ``None``. A real
    implementation would need a pretrained Inception v3 model.
    """
    return None

def calculate_fid_score(real_images, fake_images):
    """Placeholder for the Frechet Inception Distance (FID); not implemented.

    No argument is used and the function always returns ``None``. A real
    implementation would compare Inception features of both image sets;
    a lower FID means better generated images.
    """
    return None

# Diversity check, used to spot mode collapse
def check_diversity(generated_images, threshold=0.9):
    """Report the mean pairwise cosine similarity of generated images.

    Only the first 10 images are compared. Each image is flattened to a
    vector and every unordered pair is scored with cosine similarity.

    Args:
        generated_images: Indexable collection of images. Items may be
            tensors exposing ``.numpy()`` or anything ``np.asarray`` accepts.
        threshold: A mean similarity above this value prints a mode-collapse
            warning.

    Returns:
        The mean pairwise similarity, or NaN when fewer than two images are
        given (no pair exists to compare).
    """
    count = min(len(generated_images), 10)

    # Flatten each image exactly once instead of once per pair.
    vectors = []
    for index in range(count):
        image = generated_images[index]
        array = image.numpy() if hasattr(image, 'numpy') else np.asarray(image)
        vectors.append(np.asarray(array, dtype=np.float64).ravel())

    if len(vectors) < 2:
        # Averaging an empty list would give NaN with a runtime warning and
        # then fall through to the "good diversity" message.
        print("Need at least two images to measure diversity.")
        return np.float64(np.nan)

    norms = [np.linalg.norm(vector) for vector in vectors]

    similarities = []
    for i in range(len(vectors)):
        for j in range(i + 1, len(vectors)):
            denominator = norms[i] * norms[j]
            if denominator == 0:
                # Cosine similarity is undefined for a zero vector: two zero
                # images are identical, otherwise treat them as unrelated.
                similarity = 1.0 if norms[i] == norms[j] else 0.0
            else:
                similarity = np.dot(vectors[i], vectors[j]) / denominator
            similarities.append(similarity)

    avg_similarity = np.mean(similarities)
    print(f"Average similarity between generated images: {avg_similarity:.4f}")

    if avg_similarity > threshold:
        print("Warning: High similarity detected! Possible mode collapse.")
    else:
        print("Good diversity in generated images.")

    return avg_similarity

In [None]:
# Training loop with best-model tracking, periodic checkpoints and loss plots
def _plot_training_history(history):
    """Plot the per-epoch generator/discriminator losses and save the figure."""
    os.makedirs('samples/demo', exist_ok=True)

    plt.figure(figsize=(10, 4))
    panels = (('g_loss', 'Generator Loss'), ('d_loss', 'Discriminator Loss'))
    for position, (key, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history[key])
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')

    plt.tight_layout()
    plt.savefig('samples/demo/training_history.png')
    plt.show()


def train(dataset, epochs, use_wgan_gp=False, start_epoch=0, best_loss=None):
    """Train the GAN and return the per-epoch loss history.

    Args:
        dataset: Iterable yielding ``(image, caption)`` batches.
        epochs: Index of the last epoch plus one; training runs over
            ``range(start_epoch, epochs)``. ``None`` falls back to
            ``hparas['N_EPOCH']``.
        use_wgan_gp: Forwarded to ``train_step``. When true, ``train_step`` is
            called 5 times per batch.
        start_epoch: First epoch index, used when resuming training.
        best_loss: Best average generator loss seen so far, or ``None`` to
            start from infinity.

    Returns:
        A dict with the keys ``'g_loss'`` and ``'d_loss'``, each a list of
        per-epoch average losses as floats.
    """
    # hidden state of RNN
    hidden = text_encoder.initialize_hidden_state()
    steps_per_epoch = int(hparas['N_SAMPLE'] / hparas['BATCH_SIZE'])

    # Honour the `epochs` argument, which used to be silently ignored.
    total_epochs = hparas['N_EPOCH'] if epochs is None else epochs

    # Lowest average generator loss seen so far.
    best_g_loss = float('inf') if best_loss is None else float(best_loss)

    # NOTE(review): each train_step call updates BOTH networks, so repeating it
    # under WGAN-GP is not the usual "several critic updates per generator
    # update" schedule. Confirm whether train_step should be split.
    calls_per_batch = 5 if use_wgan_gp else 1

    history = {'g_loss': [], 'd_loss': []}

    for epoch in range(start_epoch, total_epochs):
        g_total_loss = 0.0
        d_total_loss = 0.0
        start = time.time()

        step = 0
        for image, caption in dataset:
            for _ in range(calls_per_batch):
                g_loss, d_loss = train_step(image, caption, hidden, use_wgan_gp)

            # Only the losses of the last call for this batch are accumulated.
            g_total_loss += g_loss
            d_total_loss += d_loss
            step += 1

            if step >= steps_per_epoch:
                break

        if step == 0:
            print('Warning: the dataset produced no batches in epoch {}'.format(epoch + 1))

        # Average over the steps that actually ran, so a dataset shorter than
        # steps_per_epoch does not deflate the reported loss.
        steps_done = max(step, 1)
        avg_g_loss = float(g_total_loss) / steps_done
        avg_d_loss = float(d_total_loss) / steps_done

        history['g_loss'].append(avg_g_loss)
        history['d_loss'].append(avg_d_loss)

        print("Epoch {}/{}, gen_loss: {:.4f}, disc_loss: {:.4f}".format(
            epoch+1, total_epochs,
            avg_g_loss,
            avg_d_loss))
        print('Time for epoch {} is {:.4f} sec'.format(epoch+1, time.time()-start))

        # Save the model with the lowest average generator loss.
        if avg_g_loss < best_g_loss:
            best_g_loss = avg_g_loss
            checkpoint.save(file_prefix = checkpoint_prefix + '_best')
            print(f'  -> Best model saved with G loss: {best_g_loss:.4f}')

        # Periodic checkpoint.
        if (epoch + 1) % hparas['SAVE_FREQ'] == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)
            print(f'  -> Checkpoint saved at epoch {epoch+1}')

        # visualization
        if (epoch + 1) % hparas['PRINT_FREQ'] == 0:
            fake_image = None
            # Only the images of the last sample batch are written to disk.
            for sample_caption in sample_sentence:
                fake_image = test_step(sample_caption, sample_seed, hidden)
            if fake_image is not None:
                os.makedirs('samples/demo', exist_ok=True)
                save_images(fake_image, [ni, ni], 'samples/demo/train_{:02d}.jpg'.format(epoch))

    _plot_training_history(history)

    return history

Evaluation

In [None]:
def testing_data_generator(caption, index):
    """Map function for the test dataset.

    Casts ``caption`` to ``tf.float32`` and passes ``index`` through
    unchanged.

    Returns:
        The pair ``(caption_as_float32, index)``.
    """
    return tf.cast(caption, tf.float32), index

def testing_dataset_generator(batch_size, data_generator):
    """Build the repeating, batched test dataset from ``testData.pkl``.

    Args:
        batch_size: Number of samples per batch.
        data_generator: Function applied to every ``(caption, index)`` pair
            through ``Dataset.map``.

    Returns:
        A ``tf.data.Dataset`` that repeats indefinitely, so the caller has to
        bound the number of batches it consumes.
    """
    frame = pd.read_pickle(os.path.join(data_path, 'testData.pkl'))

    # Stack the per-row caption id sequences into one integer array.
    caption_ids = np.asarray(list(frame['Captions'].values)).astype(int)
    sample_ids = np.asarray(frame['ID'].values)

    return (
        tf.data.Dataset.from_tensor_slices((caption_ids, sample_ids))
        .map(data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .repeat()
        .batch(batch_size)
    )

In [None]:
# Size the inference loop from the test captions. The test dataset is built
# from testData.pkl, so the batch count must come from the same file;
# counting the training captions in text2ImgData.pkl made inference cycle
# through the (repeating) test set several times.
data = pd.read_pickle(os.path.join(data_path, 'testData.pkl'))
captions = data['Captions'].values

NUM_TEST = len(captions)
# Number of full batches; a trailing partial batch is not counted.
EPOCH_TEST = int(NUM_TEST / hparas['BATCH_SIZE'])

In [36]:
# Create the inference output folder; exist_ok avoids the check-then-create race.
os.makedirs('./inference/demo', exist_ok=True)

In [None]:
# Restore trained weights for inference.
# Preference order: the most recent "best" checkpoint, then the latest
# checkpoint of any kind.
#
# Other options:
#   latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
#   specific_checkpoint = checkpoint_dir + '/ckpt-10'

print("Available checkpoints:")
checkpoint_state = tf.train.get_checkpoint_state(checkpoint_dir)
if checkpoint_state is not None:  # None when the directory has no checkpoint file
    for ckpt in checkpoint_state.all_model_checkpoint_paths:
        print(f"  - {ckpt}")

# Checkpoint.save appends "-<save counter>" to the prefix, so the best model
# is only named "ckpt_best-1" if it was the very first save. Pick the
# "ckpt_best-N" with the highest counter instead of hard-coding 1.
best_counters = []
for index_file in Path(checkpoint_dir).glob('ckpt_best-*.index'):
    match = re.fullmatch(r'ckpt_best-(\d+)\.index', index_file.name)
    if match:
        best_counters.append(int(match.group(1)))

best_checkpoint = (
    checkpoint_dir + '/ckpt_best-%d' % max(best_counters) if best_counters else None
)

status = None
if best_checkpoint is not None:
    try:
        status = checkpoint.restore(best_checkpoint)
        print(f"\nSuccessfully loaded best model: {best_checkpoint}")
    except tf.errors.OpError as err:
        # Unreadable or missing checkpoint files: fall back to the latest one.
        print(f"\nCould not restore {best_checkpoint}: {err}")

if status is None:
    print(f"\nBest model not found, loading latest checkpoint...")
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest:
        status = checkpoint.restore(latest)
        print(f"Loaded: {latest}")
    else:
        print("No checkpoint found!")

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7adc3d8fb4a0>

In [None]:
# Check whether any expected image file is missing

import os

base_image_directory = '/content/drive/MyDrive/Colab Notebooks/dl/comp3/2025-datalab-cup-3-reverse-image-caption/102flowers'

# Expected names run from image_00001.jpg to image_08189.jpg.
expected_names = (f"image_{number:05d}.jpg" for number in range(1, 8190))
missing_files = [
    name for name in expected_names
    if not os.path.exists(os.path.join(base_image_directory, name))
]

if not missing_files:
    print(f"All image files from image_00001.jpg to image_08189.jpg are present in {base_image_directory}.")
else:
    print(f"The following {len(missing_files)} files are missing from {base_image_directory}:")
    for missing_file in missing_files:
        print(missing_file)

In [None]:
# Side-by-side comparison of images generated for different captions
def visualize_results(num_examples=5):
    """Generate one image per sample caption and save them in a single row.

    Args:
        num_examples: Number of sample captions to render. Values above the
            number of built-in captions (5) are clamped.

    Raises:
        ValueError: If ``num_examples`` is smaller than 1.

    The figure is shown and written to ``samples/demo/comparison.png``.
    """
    test_captions = [
        "the flower shown has yellow anther red pistil and bright red petals.",
        "this flower has petals that are yellow, white and purple and has dark lines",
        "the petals on this flower are white with a yellow center",
        "this flower has a lot of small round pink petals.",
        "this flower is orange in color, and has petals that are ruffled and rounded."
    ]

    if num_examples < 1:
        raise ValueError("num_examples must be at least 1")
    selected_captions = test_captions[:num_examples]

    # squeeze=False always returns a 2-D array of axes, so indexing also works
    # when a single example is requested.
    fig, axes = plt.subplots(1, len(selected_captions), figsize=(20, 4), squeeze=False)
    axes = axes[0]

    hidden = text_encoder.initialize_hidden_state()

    for i, caption_text in enumerate(selected_captions):
        # Convert the text to ids. The vocabulary values may be strings, so
        # cast to int explicitly (as sample_generator does) before building
        # the int32 tensor.
        caption_ids = sent2IdList(caption_text)
        caption_batch = np.asarray([caption_ids] * hparas['BATCH_SIZE']).astype(int)
        caption_batch = tf.constant(caption_batch, dtype=tf.int32)

        # Fresh random noise for every caption.
        noise = tf.random.normal([hparas['BATCH_SIZE'], hparas['Z_DIM']])

        fake_image = test_step(caption_batch, noise, hidden)

        # Show the first image of the batch, rescaled from [-1, 1] to [0, 1]
        # and clipped so imshow never receives out-of-range values.
        axes[i].imshow(np.clip(fake_image[0].numpy() * 0.5 + 0.5, 0.0, 1.0))
        axes[i].axis('off')
        # The first 30 characters of the caption serve as the title.
        axes[i].set_title(caption_text[:30] + '...', fontsize=8)

    plt.tight_layout()
    os.makedirs('samples/demo', exist_ok=True)
    plt.savefig('samples/demo/comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("Visualization saved to samples/demo/comparison.png")

# 运行可视化（在训练后）
# visualize_results()