In [1]:
import mxnet as mx
from mxnet import nd, autograd, gluon
import numpy as np

gpu_ctx = mx.gpu(0)
cpu_ctx = mx.cpu(0)

  import OpenSSL.SSL


# Load Dataset

In [2]:
mnist = mx.test_utils.get_mnist()

# Define Encoder

In [33]:
FILTERS = 64
HIDDEN_DIM = 128
BATCH_SIZE = 100
IMG_CHNS = 1
class Encoder(gluon.Block):
    def __init__(self, FILTERS, HIDDEN_DIM, **kwarg):
        super(Encoder, self).__init__(**kwarg)
        with self.name_scope():
            # in_channel: 나중에 판단
            # layout: NCHW
            # 1 * 28 * 28
            self.conv2D_1 = gluon.nn.Conv2D(channels = 1, strides = (1, 1), kernel_size = (3, 3), padding = (1, 1), activation = 'relu')
            # 64 * 28 * 28
            self.conv2D_2 = gluon.nn.Conv2D(channels = FILTERS, strides =(1, 1), kernel_size = (3, 3), padding = (1, 1), activation = 'relu')
            # 128 * 14 * 14
            self.conv2D_3 = gluon.nn.Conv2D(channels = 2 * FILTERS, strides = (2, 2), kernel_size = (3, 3), padding= (1, 1), activation = 'relu')
            # 256 * 7 * 7
            self.conv2D_4 = gluon.nn.Conv2D(channels = 4 * FILTERS, strides = (2, 2), kernel_size = (3, 3), padding = (1, 1), activation = 'relu')
            self.flatten = gluon.nn.Flatten()
            self.Dense_1 = gluon.nn.Dense(units = HIDDEN_DIM, activation = 'relu')
            
    def forward(self, x):
        print("input shape : {}".format(x.shape))
        x = self.conv2D_1(x)
        print("conv2D_1 shape : {}".format(x.shape))
        x = self.conv2D_2(x)
        print("conv2D_2 shape : {}".format(x.shape))
        x = self.conv2D_3(x)
        print("conv2D_3 shape : {}".format(x.shape))
        x = self.conv2D_4(x)
        print("conv2D_4 shape : {}".format(x.shape))
        x = self.flatten(x)
        print("flatten shape : {}".format(x.shape))
        x = self.Dense_1(x)
        print("Dense_1 shape : {}".format(x.shape))
        return x

## check encoder dimensions

In [42]:
enc = Encoder(FILTERS, HIDDEN_DIM)
print(enc)
enc.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx = gpu_ctx)
test_input = nd.random_normal(0,1,(1, 1, 28, 28), ctx = gpu_ctx)
enc(test_input)

Encoder(
  (conv2D_1): Conv2D(None -> 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2D_2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2D_3): Conv2D(None -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2D_4): Conv2D(None -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (flatten): Flatten
  (Dense_1): Dense(None -> 128, Activation(relu))
)
input shape : (1, 1, 28, 28)
conv2D_1 shape : (1, 1, 28, 28)
conv2D_2 shape : (1, 64, 28, 28)
conv2D_3 shape : (1, 128, 14, 14)
conv2D_4 shape : (1, 256, 7, 7)
flatten shape : (1, 12544)
Dense_1 shape : (1, 128)



[[ 0.03312962  0.          0.01416666  0.02533435  0.          0.          0.
   0.00735458  0.          0.          0.          0.00912265  0.          0.
   0.00049314  0.          0.0063733   0.          0.          0.01172294
   0.01841418  0.          0.00185168  0.          0.          0.
   0.01338512  0.00092813  0.01868242  0.03296281  0.01281815  0.05120894
   0.          0.00892462  0.00820222  0.          0.          0.          0.
   0.00605315  0.          0.00330101  0.00632412  0.          0.
   0.00889342  0.02030859  0.          0.02896554  0.          0.01615953
   0.02677456  0.0077883   0.          0.          0.02658038  0.          0.
   0.          0.          0.02241181  0.02026091  0.          0.          0.
   0.00331622  0.          0.01012891  0.          0.          0.0179784
   0.          0.          0.          0.00567326  0.          0.0094079
   0.          0.0038055   0.          0.0070935   0.00430711  0.01426433
   0.          0.00962449  0.010971

# Define Decoder

In [62]:
class Decoder(gluon.Block):
    def __init__(self, HIDDEN_DIM, FILTERS, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        self.dense_h = gluon.nn.Dense(units = HIDDEN_DIM, activation = 'relu')
        self.dense_up = gluon.nn.Dense(units = 7 * 7 * FILTERS * 4, activation = 'relu')

        # 256 * 7 * 7
        self.conv_trans_1 = gluon.nn.Conv2DTranspose(channels = FILTERS * 4, kernel_size = (3, 3)\
                                                     , strides = (1, 1), padding = (1, 1)\
                                                     , activation = 'relu')
        # 128 * 14 * 14 
        self.conv_trans_2 = gluon.nn.Conv2DTranspose(channels = FILTERS * 2, kernel_size = (3, 3)\
                                                     , strides = (2, 2), padding = (1, 1)\
                                                     , activation = 'relu')
        # 64 * 28 * 28
        self.conv_trans_3 = gluon.nn.Conv2DTranspose(channels = FILTERS, kernel_size = (3, 3)\
                                                     , strides = (2, 2), padding = (1, 1)\
                                                     , activation = 'relu')
        # 1 * 28 * 28
        self.conv_trans_4 = gluon.nn.Conv2DTranspose(channels = IMG_CHNS, kernel_size = (3, 3)\
                                                     , strides = (1, 1), padding = (1, 1)\
                                                     , activation = 'sigmoid')
    
    def forward(self, x):
        x = self.dense_h(x)
        x = self.dense_up(x)
        x = x.reshape((FILTERS * 4, 7, 7))
        print('.... %s'%str(x.shape))
        x = self.conv_trans_1(x)
        x = self.conv_trans_2(x)
        x = self.conv_trans_3(x)
        x = self.conv_trans_4(x)
        return x

## Check decoder dimensions

In [65]:
test_input.shape

(128, 1)

In [63]:
dec(test_input)

.... (256, 7, 7)
infer_shape error. Arguments:
  data: (256, 7, 7)


MXNetError: Error in operator conv35_fwd: [20:36:40] src/operator/./deconvolution-inl.h:569: Check failed: dshape.ndim() == 4U (3 vs. 4) Input data should be 4D in batch-num_filter-y-x

Stack trace returned 10 entries:
[bt] (0) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x28965c) [0x7f9799daa65c]
[bt] (1) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2766582) [0x7f979c287582]
[bt] (2) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2688337) [0x7f979c1a9337]
[bt] (3) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x249c52f) [0x7f979bfbd52f]
[bt] (4) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x249f010) [0x7f979bfc0010]
[bt] (5) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(MXSymbolInferShape+0x1539) [0x7f979bf31279]
[bt] (6) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/lib-dynload/../../libffi.so.6(ffi_call_unix64+0x4c) [0x7f97d8a1cec0]
[bt] (7) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/lib-dynload/../../libffi.so.6(ffi_call+0x22d) [0x7f97d8a1c87d]
[bt] (8) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x2ce) [0x7f97d8c3182e]
[bt] (9) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(+0x12265) [0x7f97d8c32265]


In [60]:
dec = Decoder(HIDDEN_DIM, FILTERS)
print(dec)
dec.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx = gpu_ctx)
test_input = nd.random_normal(0,1,shape= (HIDDEN_DIM, 1), ctx = gpu_ctx)
dec(test_input)
dec.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx = gpu_ctx)

Decoder(
  (dense_h): Dense(None -> 128, Activation(relu))
  (dense_up): Dense(None -> 12544, Activation(relu))
  (conv_trans_1): Conv2DTranspose(256 -> 0, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_trans_2): Conv2DTranspose(128 -> 0, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_trans_3): Conv2DTranspose(64 -> 0, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_trans_4): Conv2DTranspose(1 -> 0, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
.... (256, 7, 7)
infer_shape error. Arguments:
  data: (256, 7, 7)


MXNetError: Error in operator conv35_fwd: [21:09:24] src/operator/./deconvolution-inl.h:569: Check failed: dshape.ndim() == 4U (3 vs. 4) Input data should be 4D in batch-num_filter-y-x

Stack trace returned 10 entries:
[bt] (0) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x28965c) [0x7f9799daa65c]
[bt] (1) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2766582) [0x7f979c287582]
[bt] (2) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2688337) [0x7f979c1a9337]
[bt] (3) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x249c52f) [0x7f979bfbd52f]
[bt] (4) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x249f010) [0x7f979bfc0010]
[bt] (5) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/site-packages/mxnet/libmxnet.so(MXSymbolInferShape+0x1539) [0x7f979bf31279]
[bt] (6) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/lib-dynload/../../libffi.so.6(ffi_call_unix64+0x4c) [0x7f97d8a1cec0]
[bt] (7) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/lib-dynload/../../libffi.so.6(ffi_call+0x22d) [0x7f97d8a1c87d]
[bt] (8) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x2ce) [0x7f97d8c3182e]
[bt] (9) /home/kionkim/anaconda3/envs/kion_venv_mxnet/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(+0x12265) [0x7f97d8c32265]


In [None]:
def vae_loss(x, x_gen):
    x = K.flatten(x)
    x_gen = K.flatten(x_gen)
    
    # D_KL(q(z|X) || p(z))
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    # E(log p(X|z))
    recon_loss = K.sum(K.binary_crossentropy(x_gen, x), axis=-1)

    return kl_loss + recon_loss

In [None]:
def sample_z(args):
    mean, log_var = args
    eps = nd.random_normal(shape=(BATCH_SIZE, Z_DIM), mean=0., stddev=1.)
    return mean + nd.exp(log_var / 2) * eps

class sample_layer(gluon.Block):
    def __init__(self, *args):
        super(sample_layer, self).__init__(*args)
        self.args = args

    def forward(self, x):
        return sample_z(self.args)

In [None]:
class VAE(gluon.Block):
    def __init__(self, **kwargs):
        super(VAE, self).__init__(FILTERS, HIDDEN_DIM, BATCH_SIZE)
        self.enc = Encoder(FILTERS, HIDDEN_DIM, BATCH_SIZE)
        self.dec = Decoder(HIDDEN_DIM, FILTERS)
        
    def forward(self, x)
        x = self.enc(x)
        z_mean = gluon.nn.Dense(x)
        log_z_var = gluon.nnDense(x)
        self.z = sample_layer([z_mean, log_z_var])
        x_gen = Decoder(z)
        return x_gen

In [None]:
EPOCHS = 10
train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=transform), \
                                      BATCH_SIZE, shuffle=True)

test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform), \
                              BATCH_SIZE, shuffle=False)

In [None]:
trainer = gluon.Trainer(VAE.collect_params(), 'adam', {'learning_rate': .001})

In [None]:
for e in EPOCHS:
    for i, data in enumerate(train_data):
        x = data.as_in_context(gpu_ctx)
        
    with autograd.record():
        x_gen = VAE(x)
        loss = vae_loss(x, x_gen)
    loss.backward()
        