In [None]:
import glob
import os

import matplotlib.pyplot as plt
# %matplotlib inline
import numpy as np
import tensorflow as tf
import tensorflow.keras.datasets.mnist as mnist
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Iterate over every visible GPU: indexing gpu[0] crashed on CPU-only machines, and
# TensorFlow rejects memory-growth settings that differ between GPUs.
gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
for device in gpu:
    tf.config.experimental.set_memory_growth(device, True)

In [None]:
# Load MNIST. The training split is used as the *unlabelled* pool and the test
# split (with its labels) as the *labelled* pool for semi-supervised training.
(x_train, _), (x_test, test_lable) = tf.keras.datasets.mnist.load_data()


def _to_model_range(images):
    """Scale uint8 pixels from [0, 255] to float32 in [-1, 1] and append a channel axis.

    [-1, 1] matches the range of the generator's tanh output.
    """
    return ((images.astype('float32') - 127.5) / 127.5).reshape(images.shape + (1,))


# Both pools must be on the same scale: the labelled test images were previously left as raw uint8.
x_train = _to_model_range(x_train)
x_test = _to_model_range(x_test)
train_image, test_image = x_train, x_test
print(train_image.shape, test_image.shape, test_lable.shape, train_image.dtype)

# Define the tf.data pipelines.
BATCH_SIZE = 256
shuffle_size = x_train.shape[0]
noise_dim = 100  # length of the latent vector fed to the generator G

dataset_train = (
    tf.data.Dataset.from_tensor_slices(train_image)
    .shuffle(shuffle_size)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
# The labelled pool repeats indefinitely, so `zip` in `train` is bounded only by
# dataset_train. Shuffling *before* repeat gives a fresh full permutation on every pass.
dataset_test = (
    tf.data.Dataset.from_tensor_slices((test_image, test_lable))
    .shuffle(test_image.shape[0])
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
# Dataset.prefetch() lets the input pipeline prepare upcoming batches on the CPU
# while the GPU trains on the current one, improving training throughput.
# It is used just like Dataset.batch() / Dataset.shuffle().

In [None]:
def generator_model():
    """Build the generator G: latent vector of length `noise_dim` -> 28x28x1 image in [-1, 1].

    The spatial size grows 3x3 -> 7x7 -> 14x14 -> 28x28 through three stride-2
    transposed convolutions.
    """
    def _bn_lrelu(tensor):
        # BatchNormalization followed by LeakyReLU(0.2).
        tensor = layers.BatchNormalization()(tensor)
        return layers.LeakyReLU(alpha=0.2)(tensor)

    latent = layers.Input(shape=(noise_dim,))

    # Project the latent vector and reshape it to a 3x3x256 feature map.
    h = layers.Dense(3 * 3 * 256, use_bias=False)(latent)
    h = _bn_lrelu(layers.Reshape((3, 3, 256))(h))

    # Default 'valid' padding: (3 - 1) * 2 + 3 = 7  ->  7x7x128
    h = _bn_lrelu(layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), use_bias=False)(h))
    # 'same' padding doubles the size  ->  14x14x64
    h = _bn_lrelu(layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', use_bias=False)(h))
    # ->  28x28x1, squashed by tanh into [-1, 1]
    image = tf.tanh(layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same')(h))

    return tf.keras.Model(inputs=latent, outputs=image)


def discriminator_model():
    """Build the SGAN discriminator/classifier D: 28x28x1 image -> 11-way softmax.

    Outputs 0-9 are the ten digit classes; the last output (index 10) is the
    extra "generated image" class.
    """
    def _down(tensor, filters, batch_norm=False):
        # Stride-2 3x3 conv (bias-free when followed by BatchNorm) -> [BN] -> LeakyReLU(0.2) -> Dropout(0.5)
        tensor = layers.Conv2D(filters, (3, 3), strides=(2, 2), padding='same',
                               use_bias=not batch_norm)(tensor)
        if batch_norm:
            tensor = layers.BatchNormalization()(tensor)
        tensor = layers.LeakyReLU(alpha=0.2)(tensor)
        return layers.Dropout(0.5)(tensor)

    image = tf.keras.Input(shape=(28, 28, 1))

    h = _down(image, 32)
    h = _down(h, 64, batch_norm=True)
    h = _down(h, 128)
    h = layers.GlobalAveragePooling2D()(h)

    # Ten digit classes plus one "fake" class -> 11 outputs.
    probs = layers.Dense(11, activation='softmax')(h)

    return tf.keras.Model(inputs=image, outputs=probs)


# Instantiate both networks and the two loss functions used by the loss helpers.
generator, discriminator = generator_model(), discriminator_model()
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy()  # digit-class loss on integer labels
binary_cross_entropy = tf.keras.losses.BinaryCrossentropy()  # real-vs-fake loss on D's last output

In [None]:
def discriminator_loss(d_label_real_output, label, d_unlabel_real_output, d_fake_output):
    """Total SGAN discriminator loss: the sum of three terms.

    - supervised: sparse categorical CE between `label` and the first 10 (digit) outputs
      on the labelled real batch;
    - unsupervised real: BCE pushing the last ("fake") output towards 0 on unlabelled real images;
    - unsupervised fake: BCE pushing the last ("fake") output towards 1 on generated images.
    """
    fake_prob_on_real = d_unlabel_real_output[:, -1]
    fake_prob_on_fake = d_fake_output[:, -1]

    # NOTE(review): the sliced 10 digit probabilities come from an 11-way softmax and do
    # not sum to 1; confirm whether the supervised term should renormalise them.
    supervised = cross_entropy(label, d_label_real_output[:, :-1])
    real_as_real = binary_cross_entropy(tf.zeros_like(fake_prob_on_real), fake_prob_on_real)
    fake_as_fake = binary_cross_entropy(tf.ones_like(fake_prob_on_fake), fake_prob_on_fake)

    return supervised + real_as_real + fake_as_fake

def generator_loss(d_fake_output):
    """Generator loss: BCE that rewards G when D's "fake" output (last column) on generated images is near 0."""
    fake_prob = d_fake_output[:, -1]
    return binary_cross_entropy(tf.zeros_like(fake_prob), fake_prob)


# Adam with beta_1=0.5 for both networks; the discriminator's learning rate is 100x smaller.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)

In [None]:
@tf.function
def train_step(input_label_real, labels, input_unlabel_real):
    '''Run one joint update of the generator and the discriminator.

    Args:
        input_label_real: batch of labelled real images (drawn from x_test).
        labels: digit labels for `input_label_real` (the test-split labels).
        input_unlabel_real: batch of unlabelled real images (drawn from x_train).

    Returns:
        (gen_loss, disc_loss) for this step.
    '''
    # The noise batch is always BATCH_SIZE, independent of the real batch sizes.
    z = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fakes = generator(z, training=True)
        fake_scores = discriminator(fakes, training=True)
        labelled_scores = discriminator(input_label_real, training=True)
        unlabelled_scores = discriminator(input_unlabel_real, training=True)

        gen_loss = generator_loss(fake_scores)
        disc_loss = discriminator_loss(labelled_scores, labels, unlabelled_scores, fake_scores)

    # Compute both sets of gradients before applying either update.
    g_vars = generator.trainable_variables
    d_vars = discriminator.trainable_variables
    g_grads = g_tape.gradient(gen_loss, g_vars)
    d_grads = d_tape.gradient(disc_loss, d_vars)

    generator_optimizer.apply_gradients(zip(g_grads, g_vars))
    discriminator_optimizer.apply_gradients(zip(d_grads, d_vars))

    return gen_loss, disc_loss

# Fixed latent vectors reused for every preview, so samples are comparable across epochs.
noise_dim = 100
num = 10  # number of preview images
noise_seed = tf.random.normal(shape=(num, noise_dim))

def generate_and_save_images(model, noise_input, epoch, save_path=None):
    """Show (and optionally save) one row of images generated from `noise_input`.

    Args:
        model: generator, called as ``model(noise_input, training=False)``.
        noise_input: batch of latent vectors, one per displayed image.
        epoch: zero-based epoch index; printed 1-based as a progress marker.
        save_path: optional ``savefig`` path; may contain an ``{epoch}`` placeholder
            (e.g. ``'image_at_epoch_{epoch:04d}.png'``). Nothing is written when None.
    """
    print('Epoch:', epoch+1)
    # training=False runs all layers in inference mode (affects BatchNorm).
    images = np.asarray(model(noise_input, training=False))
    if images.ndim == 4 and images.shape[-1] == 1:
        # Drop only the single grayscale channel; np.squeeze would also drop the
        # batch axis when there is just one image.
        images = images[..., 0]

    n_images = images.shape[0]
    if n_images == 0:
        return

    # One column per image, sized from the batch instead of a fixed 10-column grid.
    fig, axes = plt.subplots(1, n_images, figsize=(n_images, 1), squeeze=False)
    for ax, img in zip(axes[0], images):
        ax.imshow((img + 1) / 2, cmap='gray')  # map tanh range [-1, 1] to [0, 1] for display
        ax.axis('off')

    if save_path is not None:
        fig.savefig(save_path.format(epoch=epoch))
    plt.show()
    plt.close(fig)  # release the figure when called repeatedly during training

# Running means of the per-batch losses within one epoch (reset by `train` after each epoch).
epoch_loss_avg_gen = tf.keras.metrics.Mean(name='g_loss')
epoch_loss_avg_disc = tf.keras.metrics.Mean(name='d_loss')

# Per-epoch mean losses, appended by `train` and plotted afterwards.
g_loss_results, d_loss_results = [], []

In [None]:
def train(dataset_with_label, dataset_no_label, epochs, preview_every=20):
    """Train the SGAN for `epochs` epochs and record per-epoch mean losses.

    Args:
        dataset_with_label: iterable of (images, labels) batches (the labelled test split).
        dataset_no_label: iterable of image batches (the unlabelled train split). An
            epoch ends when the shorter of the two iterables is exhausted (zip semantics).
        epochs: number of epochs; nothing is trained or shown when it is <= 0.
        preview_every: positive int; preview generated samples every this many
            epochs, starting at epoch 0.

    Side effects:
        Appends to the module-level `g_loss_results` / `d_loss_results`, resets the
        `epoch_loss_avg_*` metrics after each epoch, and shows preview images.
    """
    last_shown = None
    for epoch in range(epochs):
        # test split has labels, train split is used without labels
        for (label_image_batch, label_batch), unlabel_image_batch in zip(dataset_with_label, dataset_no_label):
            g_loss, d_loss = train_step(label_image_batch, label_batch, unlabel_image_batch)
            epoch_loss_avg_gen(g_loss)
            epoch_loss_avg_disc(d_loss)

        g_loss_results.append(epoch_loss_avg_gen.result())
        d_loss_results.append(epoch_loss_avg_disc.result())

        epoch_loss_avg_gen.reset_states()
        epoch_loss_avg_disc.reset_states()

        if epoch % preview_every == 0:
            generate_and_save_images(generator, noise_seed, epoch)
            last_shown = epoch

    # Finish with a preview of the final epoch unless it was just shown; when
    # epochs <= 0 the loop never ran and there is nothing to show.
    if epochs > 0 and last_shown != epochs - 1:
        generate_and_save_images(generator, noise_seed, epochs - 1)

In [None]:
EPOCHS = 500
# The test split supplies the labelled batches, the train split the unlabelled ones.
train(dataset_with_label=dataset_test, dataset_no_label=dataset_train, epochs=EPOCHS)

In [None]:
# Per-epoch mean generator / discriminator losses recorded by `train`.
# Values are converted to plain floats because the history holds TensorFlow scalars.
fig, ax = plt.subplots()
ax.plot(range(1, len(g_loss_results) + 1), [float(v) for v in g_loss_results], label='g_loss')
ax.plot(range(1, len(d_loss_results) + 1), [float(v) for v in d_loss_results], label='d_loss')
ax.set(title='SGAN training losses', xlabel='epoch', ylabel='mean loss')
ax.legend()
plt.show()

In [None]:
# Save the trained generator in HDF5 format. Create the target directory first,
# since saving into a directory that does not exist fails.
GEN_MODEL_PATH = os.path.join('gen_model', 'generate_SGAN.h5')
os.makedirs(os.path.dirname(GEN_MODEL_PATH), exist_ok=True)
generator.save(GEN_MODEL_PATH)