In [1]:
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense,Activation,Flatten,Reshape
from keras.layers import Conv2D,Conv2DTranspose,UpSampling2D
from keras.layers import LeakyReLU ,Dropout
from keras.layers import BatchNormalization

from keras.optimizers import RMSprop
import matplotlib.pyplot as plt

Using TensorFlow backend.


In [2]:
class DCGAN(object):
    """Factory for the Keras networks that make up a DCGAN.

    Four objects are built lazily and cached on the instance:

    * ``D``  - the raw discriminator network,
    * ``G``  - the raw generator network,
    * ``DM`` - ``D`` wrapped and compiled for training on real/fake images,
    * ``AM`` - ``G`` followed by ``D``, compiled for training the generator.

    Args:
        img_rows: image height in pixels.
        img_cols: image width in pixels.
        channel: number of image channels.

    The generator upsamples twice by a factor of two, so its output only
    matches the discriminator input when ``img_rows`` and ``img_cols`` are
    multiples of 4 (the default 28x28 is).
    """

    def __init__(self, img_rows=28,img_cols=28,channel=1):
        # Image geometry shared by the discriminator input and generator output.
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None    # discriminator network
        self.G = None    # generator network
        self.AM = None   # compiled adversarial model (generator -> discriminator)
        self.DM = None   # compiled discriminator model

    def discriminator(self):
        """Return the (uncompiled) discriminator network, building it on first call.

        The network is four 5x5 convolutions with LeakyReLU and dropout,
        followed by a single sigmoid unit. Its summary is printed when built.
        """
        if self.D is not None:
            return self.D
        self.D = Sequential()
        # Base number of convolution filters.
        depth = 64
        # Dropout rate applied after every convolution.
        dropout = 0.4
        input_shape = (self.img_rows,self.img_cols,self.channel)

        # First convolution declares the input shape.
        # With the default 28x28x1 input this produces 14x14x64.
        self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=input_shape,padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        # Remaining convolutions as (filter multiplier, stride).
        # Default input: 7x7x128 -> 4x4x256 -> 4x4x512.
        for multiplier, stride in ((2, 2), (4, 2), (8, 1)):
            self.D.add(Conv2D(depth*multiplier, 5, strides=stride, padding='same'))
            self.D.add(LeakyReLU(alpha=0.2))
            self.D.add(Dropout(dropout))

        # Fully connected head: one sigmoid output per image.
        self.D.add(Flatten())
        self.D.add(Dense(1))
        self.D.add(Activation('sigmoid'))
        self.D.summary()
        return self.D

    def generator(self):
        """Return the (uncompiled) generator network, building it on first call.

        The network maps a 100-dimensional noise vector to an image of shape
        ``(img_rows, img_cols, channel)`` with sigmoid-activated pixels. Its
        summary is printed when built.
        """
        if self.G is not None:
            return self.G
        self.G = Sequential()
        # Dropout rate applied to the initial feature map.
        dropout = 0.4
        # Number of feature maps in the initial plane.
        depth = 64*4
        # Size of the initial plane: two 2x upsamplings bring it back to the
        # image size (7x7 for the default 28x28 images).
        rows = self.img_rows // 4
        cols = self.img_cols // 4

        # Dense projection of the noise vector, reshaped to a 3-D feature map
        # (7, 7, 256) for the default image size.
        self.G.add(Dense(rows*cols*depth,input_dim=100))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        self.G.add(Reshape((rows,cols,depth)))
        self.G.add(Dropout(dropout))

        # UpSampling2D(size=(2,2)) doubles height and width; the transposed
        # convolution that follows reduces the number of feature maps.
        # Default sizes: (14,14,256) -> (14,14,128).
        self.G.add(UpSampling2D(size=(2,2)))
        self.G.add(Conv2DTranspose(depth // 2, 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))

        # Default sizes: (28,28,128) -> (28,28,64).
        self.G.add(UpSampling2D(size=(2,2)))
        self.G.add(Conv2DTranspose(depth // 4, 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))

        # Default sizes: (28,28,64) -> (28,28,32).
        self.G.add(Conv2DTranspose(depth // 8, 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))

        # Final transposed convolution collapses to the image channels.
        self.G.add(Conv2DTranspose(self.channel, 5, padding='same'))
        self.G.add(Activation('sigmoid'))
        self.G.summary()

        return self.G

    def discriminator_model(self):
        """Return the compiled discriminator model, building it on first call.

        The model wraps ``discriminator()`` and is compiled with binary
        cross-entropy, RMSprop and an accuracy metric.
        """
        if self.DM is not None:
            return self.DM
        optimizer = RMSprop(lr=0.0002, decay=6e-8)
        discriminator = self.discriminator()
        # The discriminator must be trainable when this model is compiled.
        discriminator.trainable = True
        self.DM = Sequential()
        self.DM.add(discriminator)
        self.DM.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])
        return self.DM

    def adversarial_model(self):
        """Return the compiled adversarial model, building it on first call.

        The model stacks ``generator()`` and ``discriminator()``. The
        discriminator is frozen before compilation so that training this
        model only updates the generator; otherwise each generator step
        would also push the discriminator towards labelling fakes as real.
        """
        if self.AM is not None:
            return self.AM
        optimizer = RMSprop(lr=0.0001,decay=3e-8)
        # Compile the discriminator model first, while the discriminator is
        # still trainable; Keras keeps the trainable state seen at compile time.
        self.discriminator_model()
        discriminator = self.discriminator()
        discriminator.trainable = False
        self.AM = Sequential()
        self.AM.add(self.generator())
        self.AM.add(discriminator)
        self.AM.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])
        return self.AM
    

In [3]:
class MNIST_DCGAN(object):
    """Trains a DCGAN on the MNIST training images and plots samples.

    On construction the MNIST training set is loaded, divided by 255 and
    reshaped to ``(samples, rows, cols, channel)``; the discriminator model,
    adversarial model and generator are then obtained from a ``DCGAN``.
    """

    # Length of the noise vector fed to the generator.
    LATENT_DIM = 100

    def __init__(self):
        # Image height, width and channel count.
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1

        # Only the training images are used; labels and test split are ignored.
        (x_train, _), _ = mnist.load_data()
        self.x_train = x_train/255.0
        # Reshape to (samples, rows, cols, channel).
        self.x_train = self.x_train.reshape(-1,self.img_rows,self.img_cols,self.channel).astype(np.float32)

        self.DCGAN = DCGAN()
        # Compiled discriminator model.
        self.discriminator = self.DCGAN.discriminator_model()
        # Compiled adversarial model (generator followed by discriminator).
        self.adversarial = self.DCGAN.adversarial_model()
        # Bare generator, used for producing images.
        self.generator = self.DCGAN.generator()

    def _sample_noise(self, count):
        """Return ``count`` noise vectors drawn uniformly from [-1, 1)."""
        return np.random.uniform(-1.0, 1.0, size=[count, self.LATENT_DIM])

    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        """Alternate discriminator and generator updates for ``train_steps`` steps.

        Args:
            train_steps: number of training iterations.
            batch_size: number of real images (and of fake images) per step.
            save_interval: when > 0, a grid of generated images is saved to
                ``mnist_<step>.png`` every ``save_interval`` steps, always
                from the same 16 noise vectors.

        The loss and accuracy of both models are printed after every step.
        """
        noise_input = None
        if save_interval > 0:
            # Fixed noise so that saved snapshots are comparable across steps.
            noise_input = self._sample_noise(16)
        for i in range(train_steps):
            # --- Discriminator step ---
            # Random batch of real images (sampled with replacement).
            picks = np.random.randint(0, self.x_train.shape[0], size=batch_size)
            images_train = self.x_train[picks, :, :, :]
            # Matching batch of generated images.
            image_fake = self.generator.predict(self._sample_noise(batch_size))
            x = np.concatenate((images_train,image_fake))
            # Labels: 1 for real images (first half), 0 for generated ones.
            y = np.ones([2*batch_size,1])
            y[batch_size:,:] = 0
            d_loss = self.discriminator.train_on_batch(x,y)

            # --- Generator step (through the adversarial model) ---
            # All labels are 1: the generator is rewarded when its images
            # are classified as real.
            y = np.ones([batch_size,1])
            noise = self._sample_noise(batch_size)
            a_loss = self.adversarial.train_on_batch(noise,y)

            # Report loss and accuracy of both models.
            log_mesg = "%d: [D loss: %f, acc: %f]" % (i,d_loss[0],d_loss[1])
            log_mesg = "%s: [A loss: %f, acc: %f]" % (log_mesg,a_loss[0],a_loss[1])
            print(log_mesg)

            # Periodically save a snapshot of generated images.
            if save_interval > 0 and (i+1) % save_interval == 0:
                self.plot_images(save2file=True, samples=noise_input.shape[0],noise=noise_input,step=i+1)

    def plot_images(self, save2file=False, fake=True, samples=16,noise=None, step=0):
        """Plot a square grid of generated or real images.

        Args:
            save2file: save the figure to disk instead of showing it.
            fake: plot generator output when True, random training images
                otherwise.
            samples: number of images to draw when ``noise`` is not given
                (or when ``fake`` is False).
            noise: optional noise array for the generator; when given, one
                image per row of ``noise`` is drawn and the file name becomes
                ``mnist_<step>.png``.
            step: training step used in the file name.

        The grid is 4x4 for up to 16 images and grows to the smallest square
        that fits when more images are requested.
        """
        filename = 'mnist.png'
        if fake:
            if noise is None:
                noise = self._sample_noise(samples)
            else:
                filename = 'mnist_%d.png' %(step)
            images = self.generator.predict(noise)
        else:
            picks = np.random.randint(0, self.x_train.shape[0],samples)
            images = self.x_train[picks, :, :, :]

        count = images.shape[0]
        # Keep the 4x4 layout for <= 16 images; enlarge the grid otherwise so
        # that every subplot index is valid.
        side = max(4, int(np.ceil(np.sqrt(count))))

        fig = plt.figure(figsize=(10,10))
        for i in range(count):
            plt.subplot(side, side, i+1)
            image = np.reshape(images[i, :, :, :],[self.img_rows, self.img_cols])
            plt.imshow(image,cmap='gray')
            plt.axis('off')
        if save2file:
            plt.savefig(filename)
            # Close only the figure created here.
            plt.close(fig)
        else:
            plt.show()
            


In [None]:
# Build the MNIST DCGAN wrapper.
mnist_dcgan = MNIST_DCGAN()
# Train, saving a grid of generated images every 500 steps.
mnist_dcgan.train(
    train_steps=10000,
    batch_size=256,
    save_interval=500,
)

In [None]:
# 分别查看真实数据与造假数据
mnist_dcgan.plot_images(fake=False)
mnist_dcgan.plot_images(fake=True)

In [None]:
# 模型存储
mnist_dcgan.generator.save('generator.h5')
mnist_dcgan.discriminator.save('discriminator.h5')
mnist_dcgan.adversarial.save('adversarial.h5')