In [None]:
import mxnet as mx
import pickle
import numpy as np

In [None]:
# Restore the 'demo3' checkpoint (epoch 0) on GPU 0 and bind it for inference only:
# a single 'data' input of shape (1, 3, 112, 112) and no label shapes.
model = mx.module.Module.load('demo3', 0, context=mx.gpu(0))
model.bind(data_shapes=[('data', (1, 3, 112, 112))], label_shapes=None, for_training=False)

In [None]:
path = '/home/wuyuxiang/faces_ms1m_112x112/' + 'agedb_30.bin'
# SECURITY: pickle.load can execute arbitrary code -- only load .bin files from a trusted source.
# The file may have been pickled under Python 2; on Python 3 that needs encoding='bytes'.
try:
    with open(path, 'rb') as f:
        bins, issame_list = pickle.load(f)
except UnicodeDecodeError:
    with open(path, 'rb') as f:
        bins, issame_list = pickle.load(f, encoding='bytes')

# Two images per verification pair.
num_images = len(issame_list) * 2
assert len(bins) >= num_images, (len(bins), num_images)

# data_list[0]: decoded images; data_list[1]: the same images flipped along the width axis.
data_list = [mx.nd.empty((num_images, 3, 112, 112)) for _ in (0, 1)]
for i in range(num_images):
    img = mx.nd.transpose(mx.image.imdecode(bins[i]), axes=(2, 0, 1))  # HWC -> CHW
    data_list[0][i] = img
    data_list[1][i] = mx.nd.flip(data=img, axis=2)
    if i % 1000 == 0:
        print('loading bin', i)
print(data_list[0].shape)

In [None]:
# Wrap the first unflipped image as a batch of one (1, 3, 112, 112) for the bound module.
img = mx.nd.array(np.expand_dims(data_list[0][0].asnumpy(), axis=0))
db = mx.io.DataBatch(data=(img,), provide_data=[('data', (1, 3, 112, 112))])

In [None]:
# Inference-mode forward pass on the single-image batch.
model.forward(data_batch=db, is_train=False)

In [None]:
# Reference embedding: first module output scaled by its L2 norm.
emb_a = model.get_outputs()[0]
a = emb_a / mx.nd.norm(emb_a)

In [None]:
# Second embedding, L2-normalized the same way as `a`.
# NOTE(review): this reads the same module output as the `a` cell, so on a linear
# Restart & Run All `b` equals `a` and the next cell yields ~1.0. Re-run the batch
# and forward cells with a different image before this cell to compare two faces.
emb_b = model.get_outputs()[0]
b = emb_b / mx.nd.norm(emb_b)

In [None]:
# Dot product of the two normalized embeddings (their cosine similarity).
(a * b).sum()

In [None]:
# Cosine similarity between the reference embedding `a` and every unflipped image.
# Results are collected into an array rather than printing one NDArray per image.
sims = []
for i in range(len(data_list[0])):
    img = data_list[0][i:i + 1]  # (1, 3, 112, 112) slice; no NumPy round-trip needed
    db = mx.io.DataBatch(data=(img,), provide_data=[('data', img.shape)])
    model.forward(db, is_train=False)
    out = model.get_outputs()[0]
    b = out / mx.nd.norm(out)
    sims.append(mx.nd.sum(a * b).asscalar())
sims = np.array(sims)
sims.min(), sims.mean(), sims.max()
    

## MetaFace GAN

In [19]:
def conv(data, num_filter, kernel, name, stride=(1,1), slope=0.1, use_px=False):
    """Convolution -> InstanceNorm -> LeakyReLU, optionally followed by a 2x pixel shuffle.

    Args:
        data: input symbol, assumed laid out as (B, C, H, W) -- not checked here.
        num_filter: number of output channels of the block.
        kernel: (kh, kw); padding is kernel // 2 per axis, which preserves H, W at
            stride 1 for odd kernel sizes.
        name: suffix appended to every layer name in the block.
        stride: convolution stride.
        slope: negative slope of the LeakyReLU.
        use_px: if True, the convolution produces 4 * num_filter channels that are
            rearranged into a (B, num_filter, 2H, 2W) output (sub-pixel upsampling).

    Returns:
        The block's output symbol.
    """
    pad = (kernel[0] // 2, kernel[1] // 2)
    # Sub-pixel upsampling by 2 needs 2 * 2 = 4 conv channels per output channel.
    conv_filters = num_filter * 4 if use_px else num_filter
    conv_out = mx.sym.Convolution(data, kernel=kernel, pad=pad, stride=stride,
                                  num_filter=conv_filters, name='conv'+name)
    IN = mx.sym.InstanceNorm(conv_out, name='conv_IN'+name)
    # Honor the caller's slope (it was previously hard-coded to 0.1).
    out = mx.sym.LeakyReLU(IN, act_type='leaky', slope=slope, name='act'+name)
    if use_px:
        # (B, 4C, H, W) -> (B, C, 4, H, W) -> (B, C, 2, 2, H, W).
        # A single -4 split of the channel axis into (2, 2) would only work for 4 channels.
        px0 = mx.sym.reshape(out, shape=(0, -4, -1, 4, -2), name='px0_'+name)
        px1 = mx.sym.reshape(px0, shape=(0, 0, -4, 2, 2, -2), name='px1_'+name)
        # (B, C, 2, 2, H, W) -> (B, C, H, 2, W, 2)
        px2 = mx.sym.transpose(px1, axes=(0, 1, 4, 2, 5, 3), name='px2_'+name)
        # (B, C, H, 2, W, 2) -> (B, C, 2H, 2W)
        out = mx.sym.reshape(px2, shape=(0, 0, -3, -3), name='px3_'+name)
    return out
def fc(data, num_hidden, name):
    """FullyConnected (no bias) -> per-sample LayerNorm -> ReLU.

    Args:
        data: input symbol; FullyConnected flattens it to (B, -1) by default.
        num_hidden: output width.
        name: suffix appended to every layer name in the block.

    Returns:
        Symbol of shape (B, num_hidden).
    """
    fc1 = mx.sym.FullyConnected(data, num_hidden=num_hidden, no_bias=True, name='fc'+name)
    # InstanceNorm expects [batch, channel, spatial...] (more than 2 dims), so it cannot
    # normalize this 2-D FC output. LayerNorm normalizes each sample over its features,
    # the per-instance analog for a 2-D tensor.
    normed = mx.sym.LayerNorm(fc1, axis=-1, name='fc_LN'+name)
    out = mx.sym.Activation(normed, act_type='relu', name='act'+name)
    return out

### Decoder

In [12]:
def face_decoder():
    """Decode an identity vector into a 112x112 face blended with a real face.

    Symbol inputs (shapes as annotated in this design):
        real_faces_decoder: real faces, (B, 3, 112, 112).
        decoder_vec: identity vectors, (B, 512).

    Returns:
        Symbol computing mask * bgr + (1 - mask) * real_faces, where mask is a
        sigmoid (B, 1, 112, 112) map and bgr is a tanh (B, 3, 112, 112) image.
    """
    # Input face (B, 3, 112, 112)
    real_faces = mx.sym.Variable(name='real_faces_decoder')
    # Input vec (B, 512) -> (B, 512, 1, 1), L2-normalized per sample
    vec = mx.sym.Variable(name='decoder_vec')
    vec_reshape0 = mx.sym.reshape(vec, shape=(0, 0, 1, 1), name='decoder_vec_reshape0')
    vec_norm0 = mx.sym.L2Normalization(vec_reshape0, mode='instance', name='decoder_vec_norm0')
    # 1. Deconv 7x7 on the 1x1 map -> (B, 512, 7, 7)
    deconv1 = mx.sym.Deconvolution(vec_norm0, kernel=(7,7), num_filter=512, no_bias=True, name='decoder_deconv1')
    deconv_IN1 = mx.sym.InstanceNorm(deconv1, name='decoder_deconv_IN1')
    deconv_act1 = mx.sym.LeakyReLU(deconv_IN1, act_type='leaky', slope=0.1, name='decoder_deconv_act1')
    # 2. PX (B, 256, 14, 14)
    px2 = conv(deconv_act1, 256, kernel=(3,3), name='decoder_2', use_px=True)
    # 3. PX (B, 128, 28, 28)
    px3 = conv(px2, 128, kernel=(3,3), name='decoder_3', use_px=True)
    # 4. PX (B, 64, 56, 56)
    px4 = conv(px3, 64, kernel=(3,3), name='decoder_4', use_px=True)
    # 5. PX (B, 32, 112, 112)
    px5 = conv(px4, 32, kernel=(3,3), name='decoder_5', slope=0.2, use_px=True)
    # 6. Conv 5x5 (B, 32, 112, 112)
    conv6 = conv(px5, 32, kernel=(5,5), name='decoder_6', slope=0.2, use_px=False)
    # 7a. mask (B, 1, 112, 112), values in (0, 1)
    mask1 = mx.sym.Convolution(conv6, kernel=(5,5), pad=(2,2), stride=(1,1), num_filter=1, name='decoder_mask_conv')
    mask = mx.sym.sigmoid(mask1, name='decoder_mask_conv_sigmoid')
    # 7b. bgr (B, 3, 112, 112), values in (-1, 1)
    bgr1 = mx.sym.Convolution(conv6, kernel=(5,5), pad=(2,2), stride=(1,1), num_filter=3, name='decoder_bgr_conv')
    bgr = mx.sym.tanh(bgr1, name='decoder_bgr_conv_tanh')
    # 8. mask*bgr + (1-mask)*real_faces; the (B, 1, ...) mask broadcasts over the 3 channels.
    # `1 - mask` uses Symbol scalar arithmetic: mx.sym.broadcast_sub expects Symbol
    # operands, so broadcast_sub(1, mask) does not compose.
    fake = mx.sym.broadcast_mul(mask, bgr) + mx.sym.broadcast_mul(1 - mask, real_faces)
    return fake

### Discriminator 1

In [16]:
def Discr1():
    """Pair discriminator over channel-concatenated (image, real face) inputs.

    Symbol inputs:
        d1_lam: per-sample mixing weight, (B, 1) per this design (a (B,) or
            (B, 1, 1, 1) input also works); reshaped to (B, 1, 1, 1) below.
        d1_fake: generated faces, (B, 3, 112, 112).
        d1_real_faces: real faces, (B, 3, 112, 112).

    The network input is
        lam * concat(real, real) + (1 - lam) * concat(fake, real)   -> (B, 6, 112, 112)

    Returns:
        Feature-map symbol, (B, 512, 4, 4) for 112x112 inputs.
    """
    lam = mx.sym.Variable(name='d1_lam')
    fake = mx.sym.Variable(name='d1_fake')
    real_faces = mx.sym.Variable(name='d1_real_faces')
    # MXNet broadcasting aligns trailing axes (NumPy rules), so a (B, 1) lam would be
    # matched against the spatial axes of a (B, 6, H, W) tensor. Lift it to (B, 1, 1, 1).
    lam4 = mx.sym.reshape(lam, shape=(0, -1, 1, 1), name='d1_lam_reshape')
    real_pair = mx.sym.concat(real_faces, real_faces, dim=1)
    fake_pair = mx.sym.concat(fake, real_faces, dim=1)
    # `1 - lam4` uses Symbol scalar arithmetic; broadcast_sub expects Symbol operands.
    d1_input = mx.sym.broadcast_mul(lam4, real_pair) + mx.sym.broadcast_mul(1 - lam4, fake_pair)
    # Each stride-2, pad-1, 3x3 conv maps 112 -> 56 -> 28 -> 14 -> 7 -> 4.
    # 1. Conv (B, 3, 56, 56)
    # NOTE(review): 3 filters in the first layer is narrow compared with Discr2 (32); confirm intent.
    conv1 = conv(d1_input, 3, kernel=(3,3), stride=(2,2), slope=0.2, name='discr1_1', use_px=False)
    # 2. Conv (B, 64, 28, 28)
    conv2 = conv(conv1, 64, kernel=(3,3), stride=(2,2), slope=0.2, name='discr1_2', use_px=False)
    # 3. Conv (B, 128, 14, 14)
    conv3 = conv(conv2, 128, kernel=(3,3), stride=(2,2), slope=0.2, name='discr1_3', use_px=False)
    # 4. Conv (B, 256, 7, 7)
    conv4 = conv(conv3, 256, kernel=(3,3), stride=(2,2), slope=0.2, name='discr1_4', use_px=False)
    # 5. Conv (B, 512, 4, 4)
    out = conv(conv4, 512, kernel=(3,3), stride=(2,2), slope=0.2, name='discr1_5', use_px=False)
    return out

### Discriminator 2 (perceptual adversarial loss)

In [18]:
def Discr2():
    """Perceptual-adversarial discriminator on single images.

    Symbol input:
        d2_input: images, (B, 3, 112, 112).

    Returns:
        Group of [stage-3, stage-4, stage-5 features, 1-channel score map]:
        (B, 128, 14, 14), (B, 256, 7, 7), (B, 512, 4, 4), (B, 1, 4, 4).
    """
    d2_input = mx.sym.Variable(name='d2_input')
    # (num_filter, kernel, LeakyReLU slope) per stride-2 stage; spatial size 112 -> 56 -> 28 -> 14 -> 7 -> 4.
    stages = [
        (32, (5,5), 0.1),    # (B, 32, 56, 56)
        (64, (3,3), 0.1),    # (B, 64, 28, 28)
        (128, (3,3), 0.2),   # (B, 128, 14, 14)
        (256, (3,3), 0.2),   # (B, 256, 7, 7)
        (512, (3,3), 0.2),   # (B, 512, 4, 4)
    ]
    h = d2_input
    features = []
    for idx, (nf, k, s) in enumerate(stages, start=1):
        h = conv(h, nf, kernel=k, stride=(2,2), slope=s, name='discr2_%d' % idx, use_px=False)
        features.append(h)
    # Raw 3x3 conv to a single-channel score map (B, 1, 4, 4), no norm/activation.
    score = mx.sym.Convolution(h, num_filter=1, kernel=(3,3), stride=(1,1), pad=(1,1), name='discr2_6')
    # Expose the last three feature maps plus the score map.
    return mx.sym.Group(features[2:] + [score])

### Discriminator 3 

In [20]:
def Discr3():
    """Discriminator over identity vectors: three fc blocks and a linear score layer.

    Symbol input:
        d3_input: vectors, (B, 512) per this design.

    Returns:
        Symbol of shape (B, 1): an unnormalized, unclipped discriminator score.
    """
    d3_input = mx.sym.Variable(name='d3_input')
    # 1. fc (B, 256)
    fc1 = fc(d3_input, 256, 'discr3_1')
    # 2. fc (B, 128)
    fc2 = fc(fc1, 128, 'discr3_2')
    # 3. fc (B, 32)
    fc3 = fc(fc2, 32, 'discr3_3')
    # 4. Linear score (B, 1). Routing a single unit through fc() would normalize it
    # (per-sample normalization of one value is constant) and ReLU-clip negatives,
    # so the last layer stays raw -- as Discr2's final conv does.
    fc4 = mx.sym.FullyConnected(fc3, num_hidden=1, name='fcdiscr3_4')
    return fc4

### Loss Module

In [None]:
def loss_module():
    # d1 output  [(B, 512, 4, 4)] [0]
    d1_output = mx.sym.Variable('d1_output')
    # d2 output  [(B, 128, 16, 16), (B, 256, 8, 8), (B, 512, 4, 4), (B, 1, 4, 4)]
    d2_output0 = mx.sym.Variable('d2_output0')
    d2_output1 = mx.sym.Variable('d2_output1')
    d2_output2 = mx.sym.Variable('d2_output2')
    d2_output3 = mx.sym.Variable('d2_output3')
    # d3 output  [(B, 1)]
    d3_output = mx.sym.Variable('d3_output')
    