<a href="https://colab.research.google.com/github/zfr-1/alien_invasion/blob/master/GAN%E5%85%A5%E9%97%A8%E7%A4%BA%E4%BE%8B_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import tensorflow as tf
from tensorflow import keras 
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
import glob
%matplotlib inline
import os
from time import time

In [0]:
# TensorFlow version this notebook runs under (shown as the cell output).
tf.version.VERSION

In [0]:
(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data()

In [0]:
train_images.shape

In [0]:
train_images.dtype

In [0]:
train_images = train_images.reshape(train_images.shape[0],28, 28,1)

In [0]:
train_images.shape

In [0]:
train_images = train_images.astype('float32')

In [0]:
train_images = (train_images-127.5)/127.5

In [0]:
BATCH_SIZE = 256
BUFFLE_SIZE = 60000

In [0]:
datasets = tf.data.Dataset.from_tensor_slices(train_images)

In [0]:
datasets

In [0]:
datasets = datasets.shuffle(BUFFLE_SIZE).batch(BATCH_SIZE)

In [0]:
datasets

## 编写生成器模型

In [0]:
def generator_model(noise_dim=100):
    """Build the MLP generator that maps a noise vector to a 28x28x1 image.

    Architecture:
        Dense(256) -> BatchNorm -> LeakyReLU
        -> Dense(512) -> BatchNorm -> LeakyReLU
        -> Dense(784, tanh) -> Reshape(28, 28, 1)

    Args:
        noise_dim: length of each input noise vector (default 100).

    Returns:
        An uncompiled ``keras.Sequential`` whose outputs lie in [-1, 1].
    """
    # A single Sequential for the whole network. Re-creating it part-way through
    # would silently discard every layer added before that point.
    model = keras.Sequential()

    # Block 1: project the noise vector to 256 features. Bias is omitted because
    # the following BatchNormalization has its own learned shift.
    model.add(layers.Dense(256, input_shape=(noise_dim,), use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())  # activation layer

    # Block 2: widen to 512 features.
    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())  # activation layer

    # Output: one value per pixel, squashed to [-1, 1] by tanh. No BatchNormalization
    # after this layer: it would re-center and re-scale the tanh output and break the
    # [-1, 1] range that the training images are rescaled to.
    model.add(layers.Dense(28 * 28 * 1, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))

    return model
    
    
    

## 创建判别器模型

In [0]:
def discriminator_model():
    """Build the MLP discriminator that scores images as real or fake.

    Flattens the input, applies two hidden blocks of
    Dense (no bias) -> BatchNormalization -> LeakyReLU with 512 and then 256 units,
    and ends in a single linear unit.

    Returns:
        An uncompiled ``keras.Sequential`` producing one raw (un-activated) score
        per image. No sigmoid is applied here.
    """
    def hidden_block(units):
        # Bias-free Dense followed by BatchNormalization and LeakyReLU activation.
        return [
            layers.Dense(units, use_bias=False),
            layers.BatchNormalization(),
            layers.LeakyReLU(),
        ]

    stack = [layers.Flatten()]
    for width in (512, 256):
        stack.extend(hidden_block(width))
    stack.append(layers.Dense(1))
    return keras.Sequential(stack)

In [0]:
# Binary cross-entropy shared by both GAN losses. from_logits=True because the
# discriminator's final Dense(1) layer has no activation and emits raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [0]:
def discriminate_loss(real_out, fake_out):
    """Discriminator loss: push real scores toward label 1 and fake scores toward 0.

    Args:
        real_out: discriminator logits for a batch of real images.
        fake_out: discriminator logits for a batch of generated images.

    Returns:
        Sum of the cross-entropy on the real batch and on the fake batch.
    """
    targets_real = tf.ones_like(real_out)
    targets_fake = tf.zeros_like(fake_out)
    return cross_entropy(targets_real, real_out) + cross_entropy(targets_fake, fake_out)

In [0]:
def generator_loss(fake_out):
    """Generator loss: reward generated images the discriminator labels as real.

    Args:
        fake_out: discriminator logits for a batch of generated images.

    Returns:
        Cross-entropy of ``fake_out`` against an all-ones ("real") target.
    """
    all_real = tf.ones_like(fake_out)
    return cross_entropy(all_real, fake_out)

In [0]:
# Adam optimizers with the same learning rate for the generator and the
# discriminator. They are separate instances so each keeps its own optimizer state.
generator_opt, discriminator_opt = (
    keras.optimizers.Adam(learning_rate=1e-4),
    keras.optimizers.Adam(learning_rate=1e-4),
)

In [0]:
EPOCHS = 100
noise_dim = 100  # length of each noise vector fed to the generator

num_exp_to_generate = 16  # number of preview images plotted after each epoch

# Seed TensorFlow's global RNG so the fixed preview noise below is reproducible
# across kernel restarts (the notebook is otherwise fully stochastic).
RANDOM_SEED = 42
tf.random.set_seed(RANDOM_SEED)

# Fixed noise reused for every preview, so epochs can be compared visually.
seed = tf.random.normal([num_exp_to_generate, noise_dim])

In [0]:
# Instantiate both networks once. train_step and train reference these
# module-level names. The discriminator declares no input shape, so its weights
# are presumably created by Keras on its first call.
discriminator = discriminator_model()

In [0]:
# Generator: maps noise vectors of length noise_dim to images.
generator = generator_model()

In [0]:
@tf.function
def train_step(images):
    """Run one adversarial update of the generator and the discriminator on a batch.

    Compiled with ``tf.function`` for speed. The body has no Python side effects,
    so graph tracing is safe. Expect one extra trace for a differently-sized final batch.

    Args:
        images: a batch of real images from the input pipeline.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` for this batch, so callers may log progress.
    """
    # Sample exactly as many noise vectors as there are real images. The dataset
    # is batched without drop_remainder, so the final batch can be smaller than
    # BATCH_SIZE.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    # Record the forward pass on two tapes, one per network's gradients.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        real_out = discriminator(images, training=True)
        gen_img = generator(noise, training=True)
        fake_out = discriminator(gen_img, training=True)
        gen_loss = generator_loss(fake_out)
        disc_loss = discriminate_loss(real_out, fake_out)

    gen_gradient = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_gradient = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gen_gradient, generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))
    return gen_loss, disc_loss


In [0]:
def generate_plot_images(gen_model, test_noise):
    """Generate images from ``test_noise`` and display them in a grid.

    Args:
        gen_model: callable mapping a noise batch to images indexed as
            ``[n, height, width, channel]``. Channel 0 is plotted. Outputs are
            assumed to lie in [-1, 1] and are mapped to [0, 1] for display.
        test_noise: noise batch. Reusing the same tensor across calls makes the
            grids comparable.
    """
    pre_images = gen_model(test_noise, training=False)
    n_images = pre_images.shape[0]

    # Size the grid from the number of images instead of a fixed 4x4, which raised
    # for more than 16 samples. 16 images still give a 4x4 grid.
    n_cols = max(1, int(np.ceil(np.sqrt(n_images))))
    n_rows = int(np.ceil(n_images / n_cols))

    fig = plt.figure(figsize=(4, 4))
    for i in range(n_images):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        # Map the [-1, 1] output back to [0, 1] for imshow.
        ax.imshow((pre_images[i, :, :, 0] + 1) / 2, cmap='gray')
        ax.axis('off')
    plt.show()
    # Release the figure so a long training loop does not accumulate open figures.
    plt.close(fig)

In [0]:
def train(dataset, epochs):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    After each epoch, prints the epoch's wall-clock time and plots images
    generated from the fixed ``seed`` noise. At the end, prints the total time.

    Args:
        dataset: iterable of real-image batches (e.g. ``datasets``).
        epochs: number of passes over ``dataset``. 0 trains nothing and only
            reports the (near-zero) total time.
    """
    start_time = time()
    epoch_start_time = start_time
    for epoch in range(epochs):
        for batch in dataset:
            train_step(batch)
            print('.', end='')  # one dot per batch as a progress indicator
        print()  # end the line of progress dots
        epoch_end_time = time()
        print("第", epoch+1, "轮训练已完成")
        print("耗时:", epoch_end_time-epoch_start_time)
        epoch_start_time = epoch_end_time
        generate_plot_images(generator, seed)
    # Measure the total here instead of reusing epoch_end_time, which is never
    # assigned when epochs == 0 and would raise UnboundLocalError.
    print("所有训练已完成，共耗时：", time()-start_time)

In [0]:
# Launch training: EPOCHS passes over the shuffled, batched MNIST pipeline.
train(dataset=datasets, epochs=EPOCHS)