In [16]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, LeakyReLU, Conv2DTranspose, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist

# --- Hyperparameters ---
img_rows = 28          # image height in pixels
img_cols = 28          # image width in pixels
channels = 1           # single (grayscale) channel
noise_dim = 100        # length of the random noise vector fed to the generator
batch_size = 64        # images per discriminator/generator update
epochs = 12000         # number of training iterations (one batch each), not full passes over the data
sample_interval = 500  # print losses and save a sample grid every this many iterations


In [17]:
# Define the generator
def build_generator():
    """Build the generator network: noise vector of length `noise_dim` -> image.

    A Dense layer projects the noise to 7*7*128 values, which are reshaped to a
    7x7x128 feature map. Two stride-2 transposed convolutions then upsample it
    7x7 -> 14x14 -> 28x28, with a stride-1 transposed convolution between them.
    The final sigmoid keeps pixel values in [0, 1], the same range as the
    /255-scaled training data.

    Note: the 7x7 -> 28x28 upsampling is hard-coded; this function does not
    read `img_rows` / `img_cols`.

    Returns:
        An uncompiled Keras Sequential model mapping (batch, noise_dim) to
        (batch, 28, 28, channels).
    """
    model = Sequential()
    # An explicit Input layer replaces the `input_dim=` layer argument, which
    # newer Keras versions deprecate for Sequential models.
    model.add(Input(shape=(noise_dim,)))
    model.add(Dense(128 * 7 * 7, activation='relu'))
    model.add(Reshape((7, 7, 128)))
    model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same', activation='relu'))
    model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding='same', activation='relu'))
    model.add(Conv2DTranspose(channels, kernel_size=3, strides=2, padding='same', activation='sigmoid'))
    return model

# Define the discriminator
def build_discriminator():
    """Build the discriminator network: image -> score that the image is real.

    Two stride-2 convolutions, each followed by LeakyReLU (slope 0.2), then a
    flatten and a single sigmoid unit. During training, real images are
    labelled 1 and generated images 0.

    Returns:
        An uncompiled Keras Sequential model mapping
        (batch, img_rows, img_cols, channels) to (batch, 1).
    """
    model = Sequential()
    # Explicit Input layer instead of the deprecated `input_shape=` layer argument.
    model.add(Input(shape=(img_rows, img_cols, channels)))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
    # Slope is passed positionally: the keyword is `alpha` in tf.keras 2.x but
    # `negative_slope` in Keras 3 (where `alpha` is deprecated).
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    return model


In [18]:
# Build the generator and the discriminator
generator = build_generator()
discriminator = build_discriminator()

# Compile the discriminator on its own so it can be trained directly with
# train_on_batch. Adam(0.0002, 0.5) passes learning_rate and beta_1 positionally.
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

# Build the stacked model: noise -> generator -> discriminator -> real/fake score.
z = Input(shape=(noise_dim,))
img = generator(z)
# Freeze the discriminator *after* compiling it above and *before* compiling
# `combined`, so that training `combined` is intended to update only the generator.
# NOTE(review): this relies on Keras capturing `trainable` at compile time
# (tf.keras 2.x behaviour); newer Keras versions warn when `trainable` changes
# after compile — confirm the discriminator still trains under the installed version.
discriminator.trainable = False
real_or_fake = discriminator(img)
combined = Model(z, real_or_fake)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))


In [12]:
def load_quickdraw_data(path='cat.npy',
                        url="https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/cat.npy"):
    """Load the Quick, Draw! "cat" bitmaps as a float32 image array.

    The .npy file is downloaded from `url` only if `path` does not already
    exist, so re-running the notebook reuses the local copy instead of
    fetching the (large) dataset again.

    Args:
        path: local cache location of the .npy file.
        url: where to download the file from when it is not cached.

    Returns:
        float32 array of shape (N, img_rows, img_cols, channels) with every
        value divided by 255.
    """
    import urllib.request

    if not os.path.exists(path):
        # Download under a temporary name and rename on success, so an
        # interrupted download never leaves a truncated file at `path` that
        # would later be mistaken for a valid cache.
        tmp_path = path + '.part'
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    cats = np.load(path)

    # Preprocessing: reshape to image tensors and scale to [0, 1].
    # NOTE(review): assumes each row is a flattened 28x28 uint8 bitmap (0-255),
    # which is what the reshape and the /255 scaling rely on.
    cats = cats.reshape(-1, img_rows, img_cols, channels)
    cats = cats.astype('float32') / 255
    return cats

cats = load_quickdraw_data()


In [19]:
def train():
    """Run the adversarial training loop for `epochs` iterations.

    Each iteration:
      1. draws `batch_size` random real images from `cats` and trains the
         discriminator on them (label 1) and on a batch of generated images
         (label 0);
      2. trains the generator through `combined` (discriminator frozen) so
         that generated images are scored as real (label 1).

    Every `sample_interval` iterations, and on the final iteration, it prints
    the losses and saves a grid of samples via `sample_images`.

    Uses the module-level `cats`, `generator`, `discriminator`, `combined`,
    `batch_size`, `noise_dim`, `epochs` and `sample_interval`.
    """
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # --- Train the discriminator ---
        idx = np.random.randint(0, cats.shape[0], batch_size)
        real_imgs = cats[idx]

        noise = np.random.normal(0, 1, (batch_size, noise_dim))
        # verbose=0: newer Keras versions otherwise print a progress bar on
        # every predict call, i.e. thousands of lines over the training run.
        gen_imgs = generator.predict(noise, verbose=0)

        d_loss_real = discriminator.train_on_batch(real_imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # The discriminator was compiled with metrics=['accuracy'], so each
        # result is [loss, accuracy]; average real and fake element-wise.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Train the generator ---
        noise = np.random.normal(0, 1, (batch_size, noise_dim))
        g_loss = combined.train_on_batch(noise, valid)

        # Report progress and save samples; also on the last iteration so the
        # final state of the generator is always captured.
        if epoch % sample_interval == 0 or epoch == epochs - 1:
            print("Epoch: %d, D loss: %f, G loss: %f" % (epoch, d_loss[0], g_loss))
            sample_images(epoch)


In [20]:
def sample_images(epoch, rows=5, cols=5):
    """Generate a rows x cols grid of images from random noise and save it.

    The figure is written to ``cat_<epoch>.png`` in the current working
    directory and then closed so repeated calls do not accumulate figures.

    Args:
        epoch: training iteration number; used only in the output filename.
        rows: number of grid rows (default 5).
        cols: number of grid columns (default 5).
    """
    noise = np.random.normal(0, 1, (rows * cols, noise_dim))
    gen_imgs = generator.predict(noise, verbose=0)

    # The generator ends in a sigmoid and the training data is scaled by /255,
    # so outputs are already in [0, 1]; no `0.5 * x + 0.5` rescale is needed
    # (that only applies to a tanh output in [-1, 1]). Fixing vmin/vmax stops
    # imshow from auto-stretching each image's contrast independently.
    # squeeze=False keeps `axs` 2-D even when rows or cols is 1.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
    fig.savefig("cat_%d.png" % epoch)
    plt.close(fig)


In [None]:
# Run the training loop; it prints losses and saves a sample grid (cat_<epoch>.png)
# every `sample_interval` iterations.
train()
