In [1]:
from models import *
from wgan import *
import sys
sys.path.append('../')
from getData import *

# Wasserstein GAN

## Make dataset

In [32]:
# Load MNIST. Per the printed shapes: x is (60000, 28, 28, 1) images and y is (60000, 10) labels
# (10 columns suggests one-hot encoding — TODO confirm in getData.get_mnist). Only x is used below.
(x, y) = get_mnist()

(60000, 28, 28, 1) (60000, 10)


## Make generator

In [18]:
# Generator: maps a 100-dim latent vector (see generator_input in the summary) to a
# 28x28x1 image (see the combined model summary below), starting from a 7x7x128 feature map.
generator = build_generator(
    100,
    generator_initial_dense_layer_size=(7, 7, 128),
    generator_upsample=[1, 2, 2],
    generator_conv_filters=[128, 64, 1],
    generator_conv_kernel_size=[5, 5, 5],
    generator_conv_strides=[1, 1, 1],
    generator_batch_norm_momentum=0.8,
    generator_dropout_rate=None,
    generator_weight_init=RandomNormal(mean=0., stddev=0.02),
)
generator.summary()

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
generator_input (InputLayer) [(None, 100)]             0         
_________________________________________________________________
dense_7 (Dense)              (None, 6272)              633472    
_________________________________________________________________
batch_normalization_20 (Batc (None, 6272)              25088     
_________________________________________________________________
leaky_re_lu_24 (LeakyReLU)   (None, 6272)              0         
_________________________________________________________________
reshape_6 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
generator_conv_0 (Conv2DTran (None, 7, 7, 128)         409728    
_________________________________________________________________
batch_normalization_21 (Batc (None, 7, 7, 128)         512 

## Make Critic

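* EM(Earth Mover's, Wasserstein-1) distance를 이용해 실제 데이터 분포와 생성 분포 사이의 거리를 측정
* 기존 GAN의 JS divergence(및 KL divergence)는 두 분포가 거의 겹치지 않으면 의미 있는 gradient를 주지 못하지만, EM distance는 이 경우에도 유용한 gradient를 제공하므로 더 좋은 결과를 낼 수 있다.
* 이 덕분에 training 시 G-D 간의 balance를 덜 신경써도 무방
* GAN에서 일반적으로 발생하는 mode dropping(mode collapse)을 완화할 수 있음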

In [19]:
# Critic: takes a (28, 28, 1) image and produces a single score per image
# (see model_9's output shape (None, 1) in the combined summary below).
# critic_batch_norm_momentum=None presumably disables batch norm in the critic — TODO confirm in models.py.
critic = build_critic(
    (28, 28, 1),
    critic_conv_filters=[32, 64, 128, 128],
    critic_conv_kernel_size=[5, 5, 5, 5],
    critic_conv_strides=[2, 2, 2, 1],
    critic_batch_norm_momentum=None,
    critic_activation='leaky_relu',
    critic_dropout_rate=None,
    critic_weight_init=RandomNormal(mean=0., stddev=0.02),
)
# Show the architecture, mirroring the generator cell (previously left as commented-out code).
critic.summary()

## Make WGAN

In [27]:
# Output directory for this run, handed to WGAN (presumably where weights/samples are written — TODO confirm).
RUN_FOLDER = './results/mnist_2'
wgan = WGAN(generator, critic, RUN_FOLDER)

### compile

In [28]:
# RMSProp with learning rate 5e-5 for both networks — the defaults from
# Algorithm 1 of the WGAN paper (Arjovsky et al., 2017).
wgan.compile(
    optimizer='rmsprop',
    generator_lr=5e-5,
    critic_lr=5e-5,
)
wgan.model.summary()

Model: "model_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
model_input (InputLayer)     [(None, 100)]             0         
_________________________________________________________________
model_8 (Model)              (None, 28, 28, 1)         1275521   
_________________________________________________________________
model_9 (Model)              (None, 1)                 668801    
Total params: 1,275,521
Trainable params: 1,262,593
Non-trainable params: 12,928
_________________________________________________________________


### plot model

In [29]:
# Presumably renders/saves a diagram of the model graph (see WGAN.plot_model in wgan.py) — TODO confirm.
wgan.plot_model()

### Load weights (if you need them)

In [30]:
import os

# Optional warm start. Note the checkpoint comes from a *previous* run (results/mnist_1),
# not from RUN_FOLDER (results/mnist_2), so training below resumes from mnist_1's weights.
# The load is skipped when the file is absent so "Restart & Run All" works on a fresh checkout
# instead of raising here.
PRETRAINED_WEIGHTS = 'results/mnist_1/weights/weights.h5'
if os.path.isfile(PRETRAINED_WEIGHTS):
    wgan.load_weights(PRETRAINED_WEIGHTS)
else:
    print(f'No pretrained weights found at {PRETRAINED_WEIGHTS}; training from scratch.')

### train

In [31]:
# Values match Algorithm 1 of the WGAN paper: batch size m = 64, n_critic = 5
# critic updates per generator update, weight-clipping bound c = 0.01.
# (Mapping of these keyword names onto the paper's symbols is assumed from naming — confirm in wgan.py.)
wgan.train(
    x,
    batch_size=64,
    epochs=500,
    print_every_n_batches=5,
    n_critic=5,
    clip_threshold=0.01,
)

0 [D loss: (-0.007)(R -0.210, F 0.196)]  [G loss: -0.161] 
1 [D loss: (-0.008)(R -0.208, F 0.193)]  [G loss: -0.167] 
2 [D loss: (-0.006)(R -0.216, F 0.203)]  [G loss: -0.178] 
3 [D loss: (-0.002)(R -0.215, F 0.211)]  [G loss: -0.182] 
4 [D loss: (-0.002)(R -0.213, F 0.208)]  [G loss: -0.182] 
5 [D loss: (-0.003)(R -0.216, F 0.210)]  [G loss: -0.184] 
6 [D loss: (-0.003)(R -0.220, F 0.213)]  [G loss: -0.188] 
7 [D loss: (-0.003)(R -0.220, F 0.215)]  [G loss: -0.188] 
8 [D loss: (-0.005)(R -0.222, F 0.212)]  [G loss: -0.185] 
9 [D loss: (-0.003)(R -0.218, F 0.212)]  [G loss: -0.187] 
10 [D loss: (-0.002)(R -0.215, F 0.211)]  [G loss: -0.184] 
11 [D loss: (-0.001)(R -0.218, F 0.217)]  [G loss: -0.184] 
12 [D loss: (-0.002)(R -0.216, F 0.212)]  [G loss: -0.187] 
13 [D loss: (-0.002)(R -0.214, F 0.210)]  [G loss: -0.179] 
14 [D loss: (-0.000)(R -0.208, F 0.207)]  [G loss: -0.178] 
15 [D loss: (0.001)(R -0.200, F 0.202)]  [G loss: -0.175] 
16 [D loss: (-0.001)(R -0.204, F 0.203)]  [G loss: 

137 [D loss: (-0.033)(R -0.227, F 0.160)]  [G loss: -0.125] 
138 [D loss: (-0.039)(R -0.239, F 0.162)]  [G loss: -0.130] 
139 [D loss: (-0.032)(R -0.224, F 0.160)]  [G loss: -0.117] 
140 [D loss: (-0.034)(R -0.233, F 0.165)]  [G loss: -0.131] 
141 [D loss: (-0.044)(R -0.254, F 0.166)]  [G loss: -0.116] 
142 [D loss: (-0.035)(R -0.227, F 0.158)]  [G loss: -0.132] 
143 [D loss: (-0.040)(R -0.246, F 0.167)]  [G loss: -0.131] 
144 [D loss: (-0.040)(R -0.247, F 0.167)]  [G loss: -0.126] 
145 [D loss: (-0.039)(R -0.242, F 0.164)]  [G loss: -0.129] 
146 [D loss: (-0.043)(R -0.254, F 0.168)]  [G loss: -0.124] 
147 [D loss: (-0.046)(R -0.257, F 0.165)]  [G loss: -0.134] 
148 [D loss: (-0.050)(R -0.269, F 0.170)]  [G loss: -0.136] 
149 [D loss: (-0.038)(R -0.250, F 0.175)]  [G loss: -0.144] 
150 [D loss: (-0.053)(R -0.263, F 0.156)]  [G loss: -0.131] 
151 [D loss: (-0.058)(R -0.299, F 0.182)]  [G loss: -0.146] 
152 [D loss: (-0.039)(R -0.260, F 0.181)]  [G loss: -0.128] 
153 [D loss: (-0.038)(R 

272 [D loss: (-0.068)(R -0.360, F 0.224)]  [G loss: -0.164] 
273 [D loss: (-0.091)(R -0.404, F 0.222)]  [G loss: -0.141] 
274 [D loss: (-0.072)(R -0.382, F 0.238)]  [G loss: -0.153] 
275 [D loss: (-0.079)(R -0.376, F 0.217)]  [G loss: -0.176] 
276 [D loss: (-0.080)(R -0.385, F 0.225)]  [G loss: -0.152] 
277 [D loss: (-0.077)(R -0.388, F 0.234)]  [G loss: -0.151] 
278 [D loss: (-0.081)(R -0.384, F 0.221)]  [G loss: -0.143] 
279 [D loss: (-0.083)(R -0.345, F 0.178)]  [G loss: -0.145] 
280 [D loss: (-0.077)(R -0.408, F 0.254)]  [G loss: -0.169] 
281 [D loss: (-0.074)(R -0.388, F 0.240)]  [G loss: -0.171] 
282 [D loss: (-0.061)(R -0.369, F 0.246)]  [G loss: -0.159] 
283 [D loss: (-0.071)(R -0.375, F 0.233)]  [G loss: -0.147] 
284 [D loss: (-0.058)(R -0.377, F 0.261)]  [G loss: -0.161] 
285 [D loss: (-0.078)(R -0.414, F 0.257)]  [G loss: -0.159] 
286 [D loss: (-0.073)(R -0.387, F 0.241)]  [G loss: -0.185] 
287 [D loss: (-0.072)(R -0.397, F 0.254)]  [G loss: -0.169] 
288 [D loss: (-0.086)(R 

407 [D loss: (-0.105)(R -0.533, F 0.323)]  [G loss: -0.215] 
408 [D loss: (-0.086)(R -0.502, F 0.331)]  [G loss: -0.229] 
409 [D loss: (-0.109)(R -0.512, F 0.295)]  [G loss: -0.204] 
410 [D loss: (-0.094)(R -0.504, F 0.315)]  [G loss: -0.198] 
411 [D loss: (-0.099)(R -0.515, F 0.317)]  [G loss: -0.205] 
412 [D loss: (-0.106)(R -0.540, F 0.328)]  [G loss: -0.218] 
413 [D loss: (-0.047)(R -0.443, F 0.349)]  [G loss: -0.258] 
414 [D loss: (-0.108)(R -0.534, F 0.318)]  [G loss: -0.223] 
415 [D loss: (-0.094)(R -0.523, F 0.335)]  [G loss: -0.215] 
416 [D loss: (-0.075)(R -0.461, F 0.311)]  [G loss: -0.240] 
417 [D loss: (-0.101)(R -0.519, F 0.317)]  [G loss: -0.230] 
418 [D loss: (-0.082)(R -0.482, F 0.318)]  [G loss: -0.221] 
419 [D loss: (-0.102)(R -0.536, F 0.332)]  [G loss: -0.217] 
420 [D loss: (-0.066)(R -0.477, F 0.346)]  [G loss: -0.209] 
421 [D loss: (-0.062)(R -0.469, F 0.344)]  [G loss: -0.203] 
422 [D loss: (-0.079)(R -0.479, F 0.321)]  [G loss: -0.201] 
423 [D loss: (-0.113)(R 