<h1>DCGAN 代码实例</h1>

In [15]:
from __future__ import division
import os
import time
import math
from glob import glob
import numpy as np
from six.moves import xrange
from ops_z import *
import scipy.misc

In [16]:
from utils import pp, visualize, to_json, show_all_variables
import tensorflow as tf

计算卷积层大小的方法

In [17]:
def conv_out_size_same(size, stride):
    """Return ceil(size / stride) as an int.

    Args:
        size: Spatial extent (height or width) of the input.
        stride: Stride applied along that dimension.

    Returns:
        The quotient rounded up to the next integer.
    """
    ratio = float(size) / float(stride)
    return int(math.ceil(ratio))

<h1>DCGAN的网络结构</h1>

sess: TensorFlow 会话

batch_size: 训练数据集批次大小，需要在训练前指定完毕 

y_dim: (可选参数) 条件向量 y（例如 MNIST 标签的 one-hot 编码）的维度，默认为【None】，表示不使用条件输入

z_dim: (可选参数) 初始向量z的维度，默认为【100】

gf_dim: (可选参数) 生成器第一层卷积的卷积核维度，默认为【64】

df_dim: (可选参数) 判别器第一层卷积的卷积核维度，默认为【64】

gfc_dim: (可选参数) 生成器全连接层的神经元个数，默认为【1024】

dfc_dim: (可选参数) 判别器全连接层的神经元个数，默认为【1024】

c_dim: (可选参数) 图片的颜色维度，灰度图为【1】，默认为【3】

In [18]:
class DCGAN(object):
    def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
         batch_size=64, sample_num = 64, output_height=64, output_width=64,
         y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
         gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
         input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None):
        """Record the hyper-parameters and build the TensorFlow graph.

        Args:
            sess: TensorFlow session used for training, saving and restoring.
            input_height, input_width: Source image size handed to get_image;
                also the placeholder size when is_crop is false.
            is_crop: When true the image placeholders use the output size,
                otherwise the input size. Also forwarded to get_image.
            batch_size: Number of images per training batch (baked into the
                graph's placeholder and reshape shapes).
            sample_num: Number of images in the sample_inputs placeholder and
                of z vectors drawn for the periodic sample grid.
            output_height, output_width: Spatial size of the generated images.
            y_dim: Width of the conditioning vector y. When falsy the model
                is unconditional and no y placeholder is created.
            z_dim: Length of the noise vector z.
            gf_dim: Base channel multiplier for the generator layers.
            df_dim: Base channel multiplier for the discriminator layers.
            gfc_dim: Units of the generator's first fully connected layer
                (conditional model only).
            dfc_dim: Units of the discriminator's fully connected layer
                (conditional model only).
            c_dim: Number of colour channels; 1 marks the data as grayscale.
            dataset_name: Names the MNIST data directory and the checkpoint
                sub-directory (see model_dir).
            input_fname_pattern: Glob pattern for the training image files.
            checkpoint_dir: Directory train() tries to restore from.
            sample_dir: Directory for generated samples. It is only stored
                here; train() writes to config.sample_dir.
        """
        self.sess = sess
        self.is_crop = is_crop
        self.is_grayscale = (c_dim == 1)

        # Batch size used for training, and number of samples used for the
        # periodic preview grid (size of the sample_inputs placeholder).
        self.batch_size = batch_size
        self.sample_num = sample_num

        self.input_height = input_height
        self.input_width = input_width
        self.output_height = output_height
        self.output_width = output_width

        self.y_dim = y_dim
        self.z_dim = z_dim

        # Channel multipliers of the first generator / discriminator layers.
        self.gf_dim = gf_dim
        self.df_dim = df_dim

        # Sizes of the fully connected layers.
        self.gfc_dim = gfc_dim
        self.dfc_dim = dfc_dim

        # Number of colour channels.
        self.c_dim = c_dim

        # Batch-normalisation helpers, one per normalised layer.
        # Original author's note: batch normalization deals with poor
        # initialization and helps gradient flow.
        self.d_bn1 = batch_norm(name='d_bn1')
        self.d_bn2 = batch_norm(name='d_bn2')

        # The unconditional networks have one more normalised layer than the
        # conditional ones, so the third helper is only needed without y.
        if not self.y_dim:
            self.d_bn3 = batch_norm(name='d_bn3')

        self.g_bn0 = batch_norm(name='g_bn0')
        self.g_bn1 = batch_norm(name='g_bn1')
        self.g_bn2 = batch_norm(name='g_bn2')

        if not self.y_dim:
            self.g_bn3 = batch_norm(name='g_bn3')

        # Dataset name / directory, training-image glob pattern and the
        # checkpoint and sample directories.
        self.dataset_name = dataset_name
        self.input_fname_pattern = input_fname_pattern
        self.checkpoint_dir = checkpoint_dir
        # Previously this argument was accepted but dropped; keep it so the
        # instance records everything it was configured with.
        self.sample_dir = sample_dir

        self.build_model()
        
    def build_model(self):
        """Create placeholders, networks, losses, summaries and the saver.

        Populates (among others) self.inputs, self.sample_inputs, self.z,
        self.y (conditional model only), self.G, self.D / self.D_logits,
        self.D_ / self.D_logits_, self.d_loss, self.g_loss, self.d_vars,
        self.g_vars and self.saver.

        Note: this rebinds self.sampler on the instance from the bound method
        to the tensor that method returns, so build_model can only run once
        per instance.
        """
        if self.y_dim:
            self.y= tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')

        # With is_crop the images fed to the graph have the output size,
        # otherwise they keep the input size.
        if self.is_crop:
            image_dims = [self.output_height, self.output_width, self.c_dim]
        else:
            image_dims = [self.input_height, self.input_width, self.c_dim]

        # Placeholders for a training batch of real images and for the fixed
        # set of real images used when evaluating the periodic samples.
        self.inputs = tf.placeholder(
                          tf.float32, [self.batch_size] + image_dims, name='real_images')
        self.sample_inputs = tf.placeholder(
                          tf.float32, [self.sample_num] + image_dims, name='sample_inputs')

        inputs = self.inputs

        # Noise input and a histogram summary of it.
        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
        self.z_sum = histogram_summary("z", self.z)

        # y is an extra conditioning input next to z; it is only wired in
        # when y_dim is set.
        if self.y_dim:
            self.G = self.generator(self.z, self.y)
            self.D, self.D_logits = self.discriminator(inputs, self.y, reuse=False)

            self.sampler = self.sampler(self.z, self.y)
            self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True)
        else:
            self.G = self.generator(self.z)
            self.D, self.D_logits = self.discriminator(inputs)

            self.sampler = self.sampler(self.z)
            self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)

        # Summaries of the discriminator outputs and the generated images.
        self.d_sum = histogram_summary("d", self.D)
        self.d__sum = histogram_summary("d_", self.D_)
        self.G_sum = image_summary("G", self.G)

        def sigmoid_cross_entropy_with_logits(x, y):
            # Newer TensorFlow names the target argument `labels`, older
            # releases `targets`; an unknown keyword surfaces as TypeError.
            # Only that error triggers the fallback so genuine failures
            # (e.g. shape mismatches) are not masked.
            try:
                return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
            except TypeError:
                return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)

        # Discriminator loss on real images: logits against all-ones labels.
        self.d_loss_real = tf.reduce_mean(
        sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))
        # Discriminator loss on generated images: logits against all-zeros.
        self.d_loss_fake = tf.reduce_mean(
        sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))
        # Generator loss: the discriminator's logits on generated images
        # against all-ones labels, i.e. the generator is rewarded when its
        # images are scored as real.
        self.g_loss = tf.reduce_mean(
        sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))

        self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)

        # Total discriminator loss.
        self.d_loss = self.d_loss_real + self.d_loss_fake

        self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
        self.d_loss_sum = scalar_summary("d_loss", self.d_loss)

        t_vars = tf.trainable_variables()

        # Variables are assigned to each optimiser by substring match on
        # their names ('d_' for the discriminator, 'g_' for the generator).
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        # Saver used by save() and load().
        self.saver = tf.train.Saver()

    def train(self, config):
        """Train DCGAN.

        Each step updates the discriminator once and the generator twice,
        logs summaries to ./logs, prints the current losses, writes a sample
        grid whenever counter % 100 == 1 and a checkpoint whenever
        counter % 500 == 2.

        Args:
            config: Object providing the attributes dataset, learning_rate,
                beta1, epoch, train_size, batch_size, sample_dir and
                checkpoint_dir (the FLAGS object in this notebook).
        """
        is_mnist = config.dataset == 'mnist'

        def read_images(paths):
            # Load the given files through get_image into one float32 array;
            # grayscale data gets an explicit trailing channel axis.
            images = [
                get_image(path,
                          input_height=self.input_height,
                          input_width=self.input_width,
                          resize_height=self.output_height,
                          resize_width=self.output_width,
                          is_crop=self.is_crop,
                          is_grayscale=self.is_grayscale) for path in paths]
            stacked = np.array(images).astype(np.float32)
            return stacked[:, :, :, None] if self.is_grayscale else stacked

        def write_sample_grid(feed, epoch, idx):
            # Run the sampler on the fixed z, save the image grid and report
            # the losses measured on the fixed sample batch.
            samples, d_loss, g_loss = self.sess.run(
                [self.sampler, self.d_loss, self.g_loss], feed_dict=feed)
            manifold_h = int(np.ceil(np.sqrt(samples.shape[0])))
            manifold_w = int(np.floor(np.sqrt(samples.shape[0])))
            save_images(samples, [manifold_h, manifold_w],
                        './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
            print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))

        if is_mnist:
            data_X, data_y = self.load_mnist()
        else:
            data = glob(os.path.join("./data", config.dataset, self.input_fname_pattern))

        # Optimisers: each one only touches its own network's variables.
        d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
              .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
              .minimize(self.g_loss, var_list=self.g_vars)

        # Initialise all variables. Older TensorFlow releases lack
        # global_variables_initializer; only that case falls back, so a
        # failure while running the initialiser is no longer hidden.
        try:
            init_op = tf.global_variables_initializer()
        except AttributeError:
            init_op = tf.initialize_all_variables()
        self.sess.run(init_op)

        self.g_sum = merge_summary([self.z_sum, self.d__sum,
                self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
        self.d_sum = merge_summary(
                [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
        self.writer = SummaryWriter("./logs", self.sess.graph)

        # Fixed noise vectors reused for every sample grid.
        sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim))

        # Fixed real images (and labels) fed alongside sample_z.
        if is_mnist:
            sample_inputs = data_X[0:self.sample_num]
            sample_labels = data_y[0:self.sample_num]
        else:
            sample_inputs = read_images(data[0:self.sample_num])

        counter = 1
        start_time = time.time()

        # Resume from an existing checkpoint when one is available.
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # ************ main epoch loop ************
        for epoch in xrange(config.epoch):
            if is_mnist:
                batch_idxs = min(len(data_X), config.train_size) // config.batch_size
            else:
                # The file list is re-read at the start of every epoch.
                data = glob(os.path.join(
                      "./data", config.dataset, self.input_fname_pattern))
                batch_idxs = min(len(data), config.train_size) // config.batch_size

            for idx in xrange(0, batch_idxs):
                lo, hi = idx*config.batch_size, (idx+1)*config.batch_size
                if is_mnist:
                    batch_images = data_X[lo:hi]
                    batch_labels = data_y[lo:hi]
                else:
                    batch_images = read_images(data[lo:hi])
                batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]).astype(np.float32)

                # Feed dicts: D needs real images and z, G only z, and the
                # real-image loss only the images. The conditional model
                # additionally needs the labels everywhere.
                d_feed = {self.inputs: batch_images, self.z: batch_z}
                g_feed = {self.z: batch_z}
                real_feed = {self.inputs: batch_images}
                if is_mnist:
                    for feed in (d_feed, g_feed, real_feed):
                        feed[self.y] = batch_labels

                # Update the discriminator.
                _, summary_str = self.sess.run([d_optim, self.d_sum], feed_dict=d_feed)
                self.writer.add_summary(summary_str, counter)

                # Update the generator twice so that d_loss does not go to
                # zero (this differs from the paper).
                for _ in range(2):
                    _, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict=g_feed)
                    self.writer.add_summary(summary_str, counter)

                # Evaluate the losses with the explicit session. The previous
                # non-MNIST code passed the session as feed_dict to
                # Tensor.eval (and once no session at all), which fails when
                # no default session is installed.
                errD_fake = self.sess.run(self.d_loss_fake, feed_dict=g_feed)
                errD_real = self.sess.run(self.d_loss_real, feed_dict=real_feed)
                errG = self.sess.run(self.g_loss, feed_dict=g_feed)

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                    % (epoch, idx, batch_idxs,
                        time.time() - start_time, errD_fake+errD_real, errG))

                # Every 100 steps write a grid of generated samples.
                if np.mod(counter, 100) == 1:
                    sample_feed = {self.z: sample_z, self.inputs: sample_inputs}
                    if is_mnist:
                        sample_feed[self.y] = sample_labels
                        write_sample_grid(sample_feed, epoch, idx)
                    else:
                        # Best effort for file-based datasets: a failed
                        # sample must not abort training. KeyboardInterrupt
                        # is deliberately not caught.
                        try:
                            write_sample_grid(sample_feed, epoch, idx)
                        except Exception as exc:
                            print("one pic error!...")
                            print(" [!] %r" % (exc,))

                # Every 500 steps save a checkpoint.
                if np.mod(counter, 500) == 2:
                    self.save(config.checkpoint_dir, counter)
                    
    # 判别器
    def discriminator(self, image, y=None, reuse=False):
        """Build the discriminator inside the "discriminator" variable scope.

        Args:
            image: Batch of images to score (real inputs or generator output).
            y: Conditioning vectors; only used when self.y_dim is set.
            reuse: When true the scope's existing variables are reused.

        Returns:
            A pair (sigmoid(logits), logits).
        """
        with tf.variable_scope("discriminator") as disc_scope:
            if reuse:
                disc_scope.reuse_variables()

            if self.y_dim:
                # Conditional model: y is concatenated to the input and to
                # the output of every hidden layer.
                label_map = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
                conditioned = conv_cond_concat(image, label_map)

                conv0 = lrelu(conv2d(conditioned, self.c_dim + self.y_dim, name='d_h0_conv'))
                conv0 = conv_cond_concat(conv0, label_map)

                conv1 = conv2d(conv0, self.df_dim + self.y_dim, name='d_h1_conv')
                conv1 = lrelu(self.d_bn1(conv1))
                flat1 = tf.reshape(conv1, [self.batch_size, -1])
                flat1 = concat([flat1, y], 1)

                dense2 = linear(flat1, self.dfc_dim, 'd_h2_lin')
                dense2 = lrelu(self.d_bn2(dense2))
                dense2 = concat([dense2, y], 1)

                logits = linear(dense2, 1, 'd_h3_lin')
                return tf.nn.sigmoid(logits), logits

            # Unconditional model: four conv layers followed by one linear
            # layer that produces a single logit per image.
            conv0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
            conv1 = conv2d(conv0, self.df_dim*2, name='d_h1_conv')
            conv1 = lrelu(self.d_bn1(conv1))
            conv2 = conv2d(conv1, self.df_dim*4, name='d_h2_conv')
            conv2 = lrelu(self.d_bn2(conv2))
            conv3 = conv2d(conv2, self.df_dim*8, name='d_h3_conv')
            conv3 = lrelu(self.d_bn3(conv3))
            flat3 = tf.reshape(conv3, [self.batch_size, -1])
            logits = linear(flat3, 1, 'd_h3_lin')

            return tf.nn.sigmoid(logits), logits
    
    # 生成器        
    def generator(self, z, y=None):
        """Build the training-mode generator in the "generator" scope.

        Args:
            z: Noise input.
            y: Conditioning vectors; only used when self.y_dim is set.

        Returns:
            The generated image tensor, requested from the last deconv2d with
            shape [batch_size, output_height, output_width, c_dim]. The
            unconditional model ends in tanh, the conditional one in sigmoid.
        """
        with tf.variable_scope("generator"):
            out_h, out_w = self.output_height, self.output_width

            if self.y_dim:
                # Conditional model: two linear layers, then two deconvs,
                # with y concatenated after every hidden layer.
                half_h, quarter_h = int(out_h/2), int(out_h/4)
                half_w, quarter_w = int(out_w/2), int(out_w/4)

                label_map = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
                z = concat([z, y], 1)

                fc0 = linear(z, self.gfc_dim, 'g_h0_lin')
                fc0 = tf.nn.relu(self.g_bn0(fc0))
                fc0 = concat([fc0, y], 1)

                fc1 = linear(fc0, self.gf_dim*2*quarter_h*quarter_w, 'g_h1_lin')
                fc1 = tf.nn.relu(self.g_bn1(fc1))
                fc1 = tf.reshape(fc1, [self.batch_size, quarter_h, quarter_w, self.gf_dim * 2])
                fc1 = conv_cond_concat(fc1, label_map)

                up2 = deconv2d(fc1,
                        [self.batch_size, half_h, half_w, self.gf_dim * 2], name='g_h2')
                up2 = tf.nn.relu(self.g_bn2(up2))
                up2 = conv_cond_concat(up2, label_map)

                up3 = deconv2d(up2, [self.batch_size, out_h, out_w, self.c_dim], name='g_h3')
                return tf.nn.sigmoid(up3)

            # Unconditional model. Spatial sizes of the intermediate feature
            # maps, each obtained by halving (rounding up) the previous one.
            h2, w2 = conv_out_size_same(out_h, 2), conv_out_size_same(out_w, 2)
            h4, w4 = conv_out_size_same(h2, 2), conv_out_size_same(w2, 2)
            h8, w8 = conv_out_size_same(h4, 2), conv_out_size_same(w4, 2)
            h16, w16 = conv_out_size_same(h8, 2), conv_out_size_same(w8, 2)

            # Project z and reshape it into the smallest feature map.
            self.z_, self.h0_w, self.h0_b = linear(
                    z, self.gf_dim*8*h16*w16, 'g_h0_lin', with_w=True)
            self.h0 = tf.reshape(self.z_, [-1, h16, w16, self.gf_dim * 8])
            act0 = tf.nn.relu(self.g_bn0(self.h0))

            # Four deconvolutions, each requesting a larger output shape; the
            # weights and biases are kept on the instance.
            self.h1, self.h1_w, self.h1_b = deconv2d(
                    act0, [self.batch_size, h8, w8, self.gf_dim*4], name='g_h1', with_w=True)
            act1 = tf.nn.relu(self.g_bn1(self.h1))

            pre2, self.h2_w, self.h2_b = deconv2d(
                    act1, [self.batch_size, h4, w4, self.gf_dim*2], name='g_h2', with_w=True)
            act2 = tf.nn.relu(self.g_bn2(pre2))

            pre3, self.h3_w, self.h3_b = deconv2d(
                    act2, [self.batch_size, h2, w2, self.gf_dim*1], name='g_h3', with_w=True)
            act3 = tf.nn.relu(self.g_bn3(pre3))

            pre4, self.h4_w, self.h4_b = deconv2d(
                    act3, [self.batch_size, out_h, out_w, self.c_dim], name='g_h4', with_w=True)

            return tf.nn.tanh(pre4)
            
    def sampler(self, z, y=None):
        """Build an inference-mode copy of the generator.

        Reuses the variables of the "generator" scope and calls every
        batch-norm helper with train=False; otherwise the layer structure
        mirrors generator().

        Args:
            z: Noise input.
            y: Conditioning vectors; only used when self.y_dim is set.

        Returns:
            The generated image tensor (tanh output for the unconditional
            model, sigmoid output for the conditional one).
        """
        with tf.variable_scope("generator") as gen_scope:
            gen_scope.reuse_variables()

            out_h, out_w = self.output_height, self.output_width

            if self.y_dim:
                half_h, quarter_h = int(out_h/2), int(out_h/4)
                half_w, quarter_w = int(out_w/2), int(out_w/4)

                label_map = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
                z = concat([z, y], 1)

                fc0 = linear(z, self.gfc_dim, 'g_h0_lin')
                fc0 = tf.nn.relu(self.g_bn0(fc0, train=False))
                fc0 = concat([fc0, y], 1)

                fc1 = linear(fc0, self.gf_dim*2*quarter_h*quarter_w, 'g_h1_lin')
                fc1 = tf.nn.relu(self.g_bn1(fc1, train=False))
                fc1 = tf.reshape(fc1, [self.batch_size, quarter_h, quarter_w, self.gf_dim * 2])
                fc1 = conv_cond_concat(fc1, label_map)

                up2 = deconv2d(fc1, [self.batch_size, half_h, half_w, self.gf_dim * 2], name='g_h2')
                up2 = tf.nn.relu(self.g_bn2(up2, train=False))
                up2 = conv_cond_concat(up2, label_map)

                up3 = deconv2d(up2, [self.batch_size, out_h, out_w, self.c_dim], name='g_h3')
                return tf.nn.sigmoid(up3)

            # Unconditional model: same size schedule as generator().
            h2, w2 = conv_out_size_same(out_h, 2), conv_out_size_same(out_w, 2)
            h4, w4 = conv_out_size_same(h2, 2), conv_out_size_same(w2, 2)
            h8, w8 = conv_out_size_same(h4, 2), conv_out_size_same(w4, 2)
            h16, w16 = conv_out_size_same(h8, 2), conv_out_size_same(w8, 2)

            # Project z and reshape it into the smallest feature map.
            projected = linear(z, self.gf_dim*8*h16*w16, 'g_h0_lin')
            act0 = tf.reshape(projected, [-1, h16, w16, self.gf_dim * 8])
            act0 = tf.nn.relu(self.g_bn0(act0, train=False))

            act1 = deconv2d(act0, [self.batch_size, h8, w8, self.gf_dim*4], name='g_h1')
            act1 = tf.nn.relu(self.g_bn1(act1, train=False))

            act2 = deconv2d(act1, [self.batch_size, h4, w4, self.gf_dim*2], name='g_h2')
            act2 = tf.nn.relu(self.g_bn2(act2, train=False))

            act3 = deconv2d(act2, [self.batch_size, h2, w2, self.gf_dim*1], name='g_h3')
            act3 = tf.nn.relu(self.g_bn3(act3, train=False))

            pre4 = deconv2d(act3, [self.batch_size, out_h, out_w, self.c_dim], name='g_h4')

            return tf.nn.tanh(pre4)
            
    def load_mnist(self):
        data_dir = os.path.join("./data", self.dataset_name)
        
        fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
        loaded = np.fromfile(file=fd,dtype=np.uint8)
        trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float)

        fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
        loaded = np.fromfile(file=fd,dtype=np.uint8)
        trY = loaded[8:].reshape((60000)).astype(np.float)

        fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
        loaded = np.fromfile(file=fd,dtype=np.uint8)
        teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float)

        fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
        loaded = np.fromfile(file=fd,dtype=np.uint8)
        teY = loaded[8:].reshape((10000)).astype(np.float)

        trY = np.asarray(trY)
        teY = np.asarray(teY)
        
        X = np.concatenate((trX, teX), axis=0)
        y = np.concatenate((trY, teY), axis=0).astype(np.int)
        
        seed = 547
        np.random.seed(seed)
        np.random.shuffle(X)
        np.random.seed(seed)
        np.random.shuffle(y)
        
        y_vec = np.zeros((len(y), self.y_dim), dtype=np.float)
        for i, label in enumerate(y):
            y_vec[i,y[i]] = 1.0
        
        return X/255.,y_vec

    @property
    def model_dir(self):
        return "{}_{}_{}_{}".format(
                self.dataset_name, self.batch_size,
                self.output_height, self.output_width)
            
    def save(self, checkpoint_dir, step):
        model_name = "DCGAN.model"
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint found under checkpoint_dir.

        Args:
            checkpoint_dir: Root directory; checkpoints are looked up in its
                self.model_dir sub-directory.

        Returns:
            A pair (success, counter). On success the variables have been
            restored into self.sess and counter is the last run of digits in
            the checkpoint file name (0 if the name has none). Without a
            checkpoint the result is (False, 0). Note that the pair itself is
            always truthy, so callers must test its first element.
        """
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print(" [*] Failed to find a checkpoint")
            return False, 0

        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        # Raw string: "\d" in a normal literal is an invalid escape sequence.
        # The pattern matches the last run of digits in the name.
        step_match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
        counter = int(step_match.group(0)) if step_match else 0
        print(" [*] Success to read {}".format(ckpt_name))
        return True, counter




<h3>设置训练参数</h3>

In [19]:
# Define the command-line style flags only once per kernel, so that
# re-running this cell does not attempt to define the same flags again.
if 'already_defined' not in locals():
    already_defined = True
    flags = tf.app.flags
    flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
    flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
    flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
    # NOTE(review): np.inf is a float default on an integer flag; this worked
    # with the TensorFlow version used for the recorded output, but stricter
    # flag implementations may reject it - confirm before upgrading.
    flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
    flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
    flags.DEFINE_integer("input_height", 28, "The size of image to use (will be center cropped). [28]")
    flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
    flags.DEFINE_integer("output_height", 28, "The size of the output images to produce [28]")
    flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
    flags.DEFINE_integer("c_dim", 1, "Dimension of image color. [1]")
    flags.DEFINE_string("dataset", "mnist", "The name of dataset [celebA, mnist, lsun]")
    flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*.jpg]")
    flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
    flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
    flags.DEFINE_boolean("is_train", True, "True for training, False for testing [True]")
    flags.DEFINE_boolean("is_crop", False, "True to crop the input images to the output size, False to use the input size [False]")
    flags.DEFINE_boolean("visualize", True, "True for visualizing, False for nothing [True]")
    FLAGS = flags.FLAGS
    

In [20]:
def test():
    """Build a DCGAN from FLAGS, then train it or load a checkpoint, and
    finally run the visualisation.

    Raises:
        Exception: in test mode (is_train false) when no checkpoint can be
            loaded.
    """
    pp.pprint(flags.FLAGS.__flags)

    # Width defaults to the corresponding height.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    # Checkpoints go to checkpoint_dir, generated sample images to sample_dir.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth=True

    # Start from an empty default graph so the function can be re-run in the
    # same kernel.
    tf.reset_default_graph()

    # The `with` block installs the session as default (Tensor.eval works
    # without an explicit session) and closes it afterwards. The former
    # `tf.Graph().as_default()` and `sess.as_default()` calls had no effect
    # because their context managers were never entered.
    with tf.Session(config=run_config) as sess:
        model_kwargs = dict(
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir)

        if FLAGS.dataset == 'mnist':
            # MNIST is trained conditionally on its 10 classes, in grayscale.
            dcgan = DCGAN(sess, y_dim=10, c_dim=1, **model_kwargs)
        else:
            dcgan = DCGAN(sess, c_dim=FLAGS.c_dim, **model_kwargs)

        show_all_variables()
        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            # load() returns (success, counter); the tuple itself is always
            # truthy, so the flag has to be unpacked before testing it.
            could_load, _ = dcgan.load(FLAGS.checkpoint_dir)
            if not could_load:
                raise Exception("[!] Train a model first, then run test mode")

        # Below is codes for visualization
        OPTION = 1
        visualize(sess, dcgan, FLAGS, OPTION)
        # The trailing `sess.run()` was removed: without fetches it raises
        # TypeError.

In [21]:
# Run the experiment configured by FLAGS (trains the model when is_train is set).
test()

{'batch_size': 64,
 'beta1': 0.5,
 'c_dim': 1,
 'checkpoint_dir': 'checkpoint',
 'dataset': 'mnist',
 'epoch': 25,
 'input_fname_pattern': '*.jpg',
 'input_height': 28,
 'input_width': 28,
 'is_crop': False,
 'is_train': True,
 'learning_rate': 0.0002,
 'output_height': 28,
 'output_width': 28,
 'sample_dir': 'samples',
 'train_size': inf,
 'visualize': True}
---------
Variables: name (type shape) [size]
---------
generator/g_h0_lin/Matrix:0 (float32_ref 110x1024) [112640, bytes: 450560]
generator/g_h0_lin/bias:0 (float32_ref 1024) [1024, bytes: 4096]
generator/g_bn0/beta:0 (float32_ref 1024) [1024, bytes: 4096]
generator/g_bn0/gamma:0 (float32_ref 1024) [1024, bytes: 4096]
generator/g_h1_lin/Matrix:0 (float32_ref 1034x6272) [6485248, bytes: 25940992]
generator/g_h1_lin/bias:0 (float32_ref 6272) [6272, bytes: 25088]
generator/g_bn1/beta:0 (float32_ref 6272) [6272, bytes: 25088]
generator/g_bn1/gamma:0 (float32_ref 6272) [6272, bytes: 25088]
generator/g_h2/w:0 (float32_ref 5x5x128x138) 

KeyboardInterrupt: 