In [1]:
import os
import scipy.misc
import numpy as np

from model import DCGAN
from utils import pp, visualize, to_json, show_all_variables

import tensorflow as tf

In [None]:
# Command-line flags consumed by main() and handed to the DCGAN model.
# Convention used in the help strings: the value in [brackets] is the default.
#
# The previously used (commented-out) defaults were input_height=108,
# output_height=64, dataset="celebA" and train=False; the values below are the
# ones this notebook runs with (28x28 images from the "mnist" dataset, training on).
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
# NOTE(review): np.inf is a float used as the default of an integer flag. It is
# kept as-is because the consumers of train_size are not visible here; newer
# flag implementations may reject this default -- confirm before upgrading TF.
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 28, "The size of image to use (will be center cropped). [28]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 28, "The size of the output images to produce [28]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "mnist", "The name of dataset: celebA, mnist, lsun [mnist]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*.jpg]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", True, "True for training, False for testing [True]")
flags.DEFINE_boolean("crop", False, "True to crop the input images, False to use them as is [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
flags.DEFINE_integer("loss_type", 2, "Loss type: 0=cross entropy, 1=logloss, 2=wasserstein [2]")
FLAGS = flags.FLAGS

In [None]:
def main(_):
  """Build a DCGAN from FLAGS, train it or restore a checkpoint, then visualize.

  Steps, in order:
    1. Pretty-print the flag dictionary.
    2. Default input_width / output_width to the matching height when unset.
    3. Create the checkpoint and sample directories if they do not exist.
    4. Open a TensorFlow session, construct the DCGAN, list its variables.
    5. Train when FLAGS.train is set; otherwise load a checkpoint.
    6. Call visualize() with option 1.

  Args:
    _: Ignored positional argument (presumably the argv list supplied by
      tf.app.run -- confirm against the TensorFlow docs).

  Raises:
    Exception: in test mode, when the first element of the value returned by
      dcgan.load() is falsy.
  """
  pp.pprint(flags.FLAGS.__flags)

  # A width that was not given falls back to the corresponding height.
  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  # Make sure both output directories are present before the model uses them.
  for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
    if not os.path.exists(directory):
      os.makedirs(directory)

  # Alternative kept for reference:
  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  session_config = tf.ConfigProto()
  session_config.gpu_options.allow_growth = True

  with tf.Session(config=session_config) as sess:
    # Keyword arguments shared by every dataset.
    # NOTE(review): z_dim is fed from the generate_test_images flag.
    model_kwargs = dict(
        input_width=FLAGS.input_width,
        input_height=FLAGS.input_height,
        output_width=FLAGS.output_width,
        output_height=FLAGS.output_height,
        batch_size=FLAGS.batch_size,
        sample_num=FLAGS.batch_size,
        z_dim=FLAGS.generate_test_images,
        dataset_name=FLAGS.dataset,
        input_fname_pattern=FLAGS.input_fname_pattern,
        crop=FLAGS.crop,
        checkpoint_dir=FLAGS.checkpoint_dir,
        sample_dir=FLAGS.sample_dir)

    # Only the mnist configuration passes y_dim (fixed at 10).
    if FLAGS.dataset == 'mnist':
      model_kwargs['y_dim'] = 10

    dcgan = DCGAN(sess, **model_kwargs)

    show_all_variables()

    if FLAGS.train:
      dcgan.train(FLAGS)
    elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
      # Test mode needs a previously saved checkpoint.
      raise Exception("[!] Train a model first, then run test mode")

    # Disabled export of layer weights, kept for reference:
    # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
    #                 [dcgan.h4_w, dcgan.h4_b, None])

    # Visualization runs after both training and loading.
    # NOTE(review): the "visualize" flag defined in this file is not consulted here.
    visualize_option = 1
    visualize(sess, dcgan, FLAGS, visualize_option)

# Entry point: start the app only when this code is executed as the main
# module (not when imported). tf.app.run() is called without arguments;
# it presumably parses the flags defined above and then calls main(_) --
# confirm against the TensorFlow docs for the installed version.
if __name__ == '__main__':
  tf.app.run()

{'batch_size': 64,
 'beta1': 0.5,
 'checkpoint_dir': 'checkpoint',
 'crop': False,
 'dataset': 'mnist',
 'epoch': 25,
 'generate_test_images': 100,
 'input_fname_pattern': '*.jpg',
 'input_height': 28,
 'input_width': None,
 'learning_rate': 0.0002,
 'output_height': 28,
 'output_width': None,
 'sample_dir': 'samples',
 'train': True,
 'train_size': inf,
 'visualize': False}
---------
Variables: name (type shape) [size]
---------
generator/g_h0_lin/Matrix:0 (float32_ref 110x1024) [112640, bytes: 450560]
generator/g_h0_lin/bias:0 (float32_ref 1024) [1024, bytes: 4096]
generator/g_bn0/beta:0 (float32_ref 1024) [1024, bytes: 4096]
generator/g_bn0/gamma:0 (float32_ref 1024) [1024, bytes: 4096]
generator/g_h1_lin/Matrix:0 (float32_ref 1034x6272) [6485248, bytes: 25940992]
generator/g_h1_lin/bias:0 (float32_ref 6272) [6272, bytes: 25088]
generator/g_bn1/beta:0 (float32_ref 6272) [6272, bytes: 25088]
generator/g_bn1/gamma:0 (float32_ref 6272) [6272, bytes: 25088]
generator/g_h2/w:0 (float32_r

Epoch: [ 0] [  70/1093] time: 87.9636, d_loss: -0.00488281, g_loss: -2215.37451172
Epoch: [ 0] [  71/1093] time: 89.1325, d_loss: -0.00610352, g_loss: -2222.34545898
Epoch: [ 0] [  72/1093] time: 90.3048, d_loss: -0.00708008, g_loss: -2228.19604492
Epoch: [ 0] [  73/1093] time: 91.4649, d_loss: -0.00634766, g_loss: -2218.78271484
Epoch: [ 0] [  74/1093] time: 92.6986, d_loss: -0.01123047, g_loss: -2224.25830078
Epoch: [ 0] [  75/1093] time: 93.9438, d_loss: -0.01123047, g_loss: -2223.73779297
Epoch: [ 0] [  76/1093] time: 95.1690, d_loss: -0.01367188, g_loss: -2224.57861328
Epoch: [ 0] [  77/1093] time: 96.5516, d_loss: -0.01269531, g_loss: -2226.05761719
Epoch: [ 0] [  78/1093] time: 97.7484, d_loss: -0.01220703, g_loss: -2224.92773438
Epoch: [ 0] [  79/1093] time: 99.0932, d_loss: -0.01513672, g_loss: -2221.59863281
Epoch: [ 0] [  80/1093] time: 100.4710, d_loss: -0.01489258, g_loss: -2222.02709961
Epoch: [ 0] [  81/1093] time: 101.6999, d_loss: -0.01538086, g_loss: -2219.89111328
Ep

Epoch: [ 0] [ 168/1093] time: 218.4017, d_loss: 0.02368164, g_loss: -2224.05639648
Epoch: [ 0] [ 169/1093] time: 219.5937, d_loss: 0.01928711, g_loss: -2214.00732422
Epoch: [ 0] [ 170/1093] time: 220.9738, d_loss: 0.01757812, g_loss: -2219.13916016
Epoch: [ 0] [ 171/1093] time: 222.3625, d_loss: 0.00488281, g_loss: -2215.26367188
Epoch: [ 0] [ 172/1093] time: 223.6812, d_loss: 0.01416016, g_loss: -2218.24511719
Epoch: [ 0] [ 173/1093] time: 224.8951, d_loss: 0.00561523, g_loss: -2212.90185547
Epoch: [ 0] [ 174/1093] time: 226.1578, d_loss: 0.01757812, g_loss: -2214.77099609
Epoch: [ 0] [ 175/1093] time: 227.5044, d_loss: 0.00292969, g_loss: -2210.82983398
Epoch: [ 0] [ 176/1093] time: 228.6908, d_loss: -0.02343750, g_loss: -2220.27294922
Epoch: [ 0] [ 177/1093] time: 229.8473, d_loss: -0.01611328, g_loss: -2218.48046875
Epoch: [ 0] [ 178/1093] time: 231.0021, d_loss: -0.00878906, g_loss: -2211.39208984
Epoch: [ 0] [ 179/1093] time: 232.1535, d_loss: -0.01416016, g_loss: -2220.45507812


Epoch: [ 0] [ 266/1093] time: 336.5356, d_loss: -0.03442383, g_loss: -2205.50024414
Epoch: [ 0] [ 267/1093] time: 337.6761, d_loss: -0.04345703, g_loss: -2216.62939453
Epoch: [ 0] [ 268/1093] time: 338.8312, d_loss: -0.03051758, g_loss: -2211.31445312
Epoch: [ 0] [ 269/1093] time: 339.9964, d_loss: -0.01367188, g_loss: -2197.66943359
Epoch: [ 0] [ 270/1093] time: 341.1665, d_loss: -0.00878906, g_loss: -2217.24755859
Epoch: [ 0] [ 271/1093] time: 342.3267, d_loss: -0.00219727, g_loss: -2215.98315430
Epoch: [ 0] [ 272/1093] time: 343.4801, d_loss: -0.00268555, g_loss: -2208.71533203
Epoch: [ 0] [ 273/1093] time: 344.6780, d_loss: -0.01293945, g_loss: -2207.70629883
Epoch: [ 0] [ 274/1093] time: 345.8305, d_loss: 0.00732422, g_loss: -2197.89208984
Epoch: [ 0] [ 275/1093] time: 346.9732, d_loss: 0.00170898, g_loss: -2204.46655273
Epoch: [ 0] [ 276/1093] time: 348.1257, d_loss: 0.01000977, g_loss: -2211.01342773
Epoch: [ 0] [ 277/1093] time: 349.2667, d_loss: 0.01147461, g_loss: -2199.08374

Epoch: [ 0] [ 364/1093] time: 455.7802, d_loss: -0.06201172, g_loss: -2212.21777344
Epoch: [ 0] [ 365/1093] time: 456.9381, d_loss: -0.06005859, g_loss: -2207.75927734
Epoch: [ 0] [ 366/1093] time: 458.1391, d_loss: -0.08862305, g_loss: -2204.62451172
Epoch: [ 0] [ 367/1093] time: 459.5006, d_loss: -0.07104492, g_loss: -2206.97680664
Epoch: [ 0] [ 368/1093] time: 460.7147, d_loss: -0.05273438, g_loss: -2196.28735352
Epoch: [ 0] [ 369/1093] time: 461.9788, d_loss: -0.09960938, g_loss: -2200.24121094
Epoch: [ 0] [ 370/1093] time: 463.1916, d_loss: -0.07495117, g_loss: -2208.74291992
Epoch: [ 0] [ 371/1093] time: 464.3461, d_loss: -0.03442383, g_loss: -2208.79882812
Epoch: [ 0] [ 372/1093] time: 465.6592, d_loss: -0.03540039, g_loss: -2204.59472656
Epoch: [ 0] [ 373/1093] time: 467.0968, d_loss: -0.03417969, g_loss: -2206.73803711
Epoch: [ 0] [ 374/1093] time: 468.3270, d_loss: -0.04223633, g_loss: -2208.12377930
Epoch: [ 0] [ 375/1093] time: 469.6481, d_loss: -0.05712891, g_loss: -2197.4

Epoch: [ 0] [ 462/1093] time: 573.6374, d_loss: -0.05371094, g_loss: -2205.69970703
Epoch: [ 0] [ 463/1093] time: 574.8013, d_loss: -0.06762695, g_loss: -2204.03588867
Epoch: [ 0] [ 464/1093] time: 575.9702, d_loss: -0.03881836, g_loss: -2198.11401367
Epoch: [ 0] [ 465/1093] time: 577.1300, d_loss: -0.02978516, g_loss: -2197.67041016
Epoch: [ 0] [ 466/1093] time: 578.2846, d_loss: -0.03417969, g_loss: -2186.40869141
Epoch: [ 0] [ 467/1093] time: 579.4286, d_loss: -0.02807617, g_loss: -2200.46020508
Epoch: [ 0] [ 468/1093] time: 580.5990, d_loss: -0.02148438, g_loss: -2197.94287109
Epoch: [ 0] [ 469/1093] time: 581.8113, d_loss: -0.01147461, g_loss: -2186.49462891
Epoch: [ 0] [ 470/1093] time: 582.9781, d_loss: -0.03662109, g_loss: -2187.58984375
Epoch: [ 0] [ 471/1093] time: 584.1424, d_loss: -0.02221680, g_loss: -2202.60693359
Epoch: [ 0] [ 472/1093] time: 585.3090, d_loss: -0.03857422, g_loss: -2198.50219727
Epoch: [ 0] [ 473/1093] time: 586.4723, d_loss: 0.00537109, g_loss: -2206.63