In [1]:
import tqdm
import numpy as np

In [2]:
# Load the dataset as (class, example, 28, 28); first 1200 classes for training, the rest for testing.
data = np.load('data.npy').reshape(-1, 20, 28, 28)
train_data, test_data = data[:1200], data[1200:]

In [3]:
def get_minibatch(nb_batch, nb_episode, nb_class, data):
    """Sample a batch of few-shot classification episodes.

    For every episode, ``nb_class`` distinct classes are drawn from ``data``
    without replacement and mapped to episode-local labels ``0..nb_class-1``.
    ``nb_episode`` labels are then sampled uniformly, and every position is
    filled with a distinct example of its class. All examples of one class
    within an episode share one random rotation by a multiple of 90 degrees.

    Args:
        nb_batch: number of episodes in the batch.
        nb_episode: number of time steps (images) per episode.
        nb_class: number of classes per episode.
        data: array indexed as ``data[class, example, ...]``. The writes into
            ``batch_x[..., :, :, 0]`` assume each example is a 2-D image.

    Returns:
        batch_x: float array of shape ``(nb_batch, nb_episode) + data.shape[2:] + (1,)``.
        batch_y: float array ``(nb_batch, nb_episode)`` of episode-local labels.
        batch_mask: float array ``(nb_batch, nb_episode)``; 0 at the first
            occurrence of each label within an episode, 1 everywhere else.
        batch_test: float array ``(nb_batch, nb_episode, 2)`` holding the source
            ``(class index, example index)`` into ``data`` for every position.

    Raises:
        ValueError: (from ``np.random.choice``) if ``nb_class > len(data)``, or
            if one label is sampled more times in an episode than
            ``data.shape[1]`` examples exist for it.
    """
    dshape = data.shape[2:]
    batch_x = np.zeros((nb_batch, nb_episode) + dshape + (1,))
    batch_test = np.zeros((nb_batch, nb_episode, 2))
    batch_y = np.zeros((nb_batch, nb_episode))
    batch_mask = np.ones((nb_batch, nb_episode))
    nb_class_total = len(data)

    for i in range(nb_batch):
        # Classes used in this episode (distinct, drawn without replacement).
        classes = np.random.choice(nb_class_total, nb_class, False)

        # Episode label j maps to dataset class classes[pidx[j]].
        pidx = np.random.permutation(nb_class)

        # Episode-local label sequence.
        sample = np.random.randint(0, nb_class, nb_episode)
        batch_y[i] = sample

        # Mask out the first occurrence of each label.
        _, first = np.unique(sample, return_index=True)
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        mask = np.ones(nb_episode, dtype=bool)
        mask[first] = False
        batch_mask[i] = mask

        for j in range(nb_class):
            idx = (sample == j)
            cls = classes[pidx[j]]
            # Distinct examples of this class, one per position labelled j.
            eidx = np.random.choice(data.shape[1], np.count_nonzero(idx), False)
            imgs = data[cls, eidx]

            # One random rotation shared by every image of this class in the episode.
            batch_x[i, idx, :, :, 0] = np.rot90(imgs, np.random.randint(4), axes=(1, 2))
            batch_test[i, idx, 0] = cls
            batch_test[i, idx, 1] = eidx

    return batch_x, batch_y, batch_mask, batch_test

In [4]:
# Episode configuration: episodes per batch, steps per episode, classes per episode.
nb_batch, nb_episode, nb_class = 4, 32, 5

In [5]:
batch_x, batch_y, batch_m, _ = get_minibatch(
    nb_batch=nb_batch, nb_episode=nb_episode, nb_class=nb_class, data=train_data)

In [6]:
import tensorflow as tf
# TF1-style interactive session; the sess.run calls later in the notebook use it.
sess = tf.InteractiveSession()

In [7]:
def make_input(batch_y, nb_class):
    """Build the "previous label" input: labels shifted right by one step, one-hot.

    Position t of each episode holds the one-hot encoding of ``batch_y[:, t-1]``.
    Position 0 has no predecessor, so it is all zeros.

    Args:
        batch_y: array ``(batch, time)`` of integer-valued labels in ``[0, nb_class)``.
        nb_class: number of classes (one-hot width).

    Returns:
        float32 array of shape ``(batch, time, nb_class)``.
    """
    batch_y = np.asarray(batch_y)
    # Use the actual batch size rather than a global, so any batch size works.
    nb_rows = batch_y.shape[0]
    prev_onehot = (batch_y[:, :-1, None] == np.arange(nb_class)).astype(np.float32)
    dummy = np.zeros((nb_rows, 1, nb_class), dtype=np.float32)
    return np.concatenate((dummy, prev_onehot), axis=1)

In [8]:
batch_p = make_input(batch_y=batch_y, nb_class=nb_class)

In [9]:
# Shape of the sampled image batch
np.shape(batch_x)

(4, 32, 28, 28, 1)

In [24]:
# Shape of the previous-label (shifted one-hot) input
np.shape(batch_p)

(4, 32, 5)

In [11]:
# Episode-local labels sampled by get_minibatch (values in [0, nb_class))
batch_y

array([[ 1.,  3.,  4.,  2.,  1.,  2.,  0.,  3.,  1.,  0.,  2.,  0.,  0.,
         0.,  3.,  1.,  0.,  4.,  1.,  2.,  1.,  0.,  2.,  4.,  0.,  3.,
         1.,  0.,  4.,  0.,  2.,  1.],
       [ 1.,  1.,  3.,  4.,  2.,  2.,  4.,  2.,  3.,  0.,  0.,  1.,  3.,
         3.,  0.,  2.,  2.,  3.,  2.,  0.,  3.,  3.,  0.,  2.,  4.,  2.,
         0.,  1.,  2.,  1.,  2.,  0.],
       [ 2.,  1.,  2.,  0.,  2.,  1.,  4.,  0.,  4.,  0.,  1.,  2.,  4.,
         3.,  2.,  3.,  1.,  2.,  3.,  3.,  3.,  1.,  2.,  4.,  1.,  3.,
         3.,  1.,  3.,  2.,  1.,  2.],
       [ 3.,  0.,  2.,  0.,  1.,  1.,  4.,  4.,  3.,  3.,  4.,  4.,  1.,
         3.,  1.,  0.,  1.,  2.,  1.,  4.,  4.,  0.,  2.,  0.,  0.,  1.,
         3.,  1.,  3.,  4.,  2.,  2.]])

In [12]:
# Mask: 0 at the first occurrence of each label in an episode, 1 elsewhere
print(batch_m)

[[ 0.  0.  0.  0.  1.  1.  0.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
   1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]
 [ 0.  1.  0.  0.  0.  1.  1.  1.  1.  0.  1.  1.  1.  1.  1.  1.  1.  1.
   1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]
 [ 0.  0.  1.  0.  1.  1.  0.  1.  1.  1.  1.  1.  1.  0.  1.  1.  1.  1.
   1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]
 [ 0.  0.  0.  1.  0.  1.  0.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
   1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]]


In [13]:
def embd_net(inp, scope, reuse=False, stop_grad=False):
    """Embed every image of every episode with a 4-layer conv net.

    Each layer applies a 3x3 conv (64 filters, SAME), batch normalization
    using the current batch's statistics (``tf.nn.moments``, no moving
    averages), ReLU, and a 2x2/stride-2 VALID max-pool. For 28x28 inputs the
    spatial size goes 28 -> 14 -> 7 -> 3 -> 1, so the final map is squeezed
    to a 64-vector per image.

    Args:
        inp: image tensor, reshaped internally to ``[-1, 28, 28, 1]``. Its
            static ``shape[1]`` is taken as the episode length.
        scope: variable scope name.
        reuse: if True, reuse the variables already created under ``scope``.
        stop_grad: if True, block gradients from flowing back into the embedding.

    Returns:
        Tensor of shape ``(batch, nb_episode, 64)``.
    """
    episode_len = int(inp.shape[1])

    with tf.variable_scope(scope) as vs:
        if reuse:
            vs.reuse_variables()

        h = tf.reshape(inp, [-1, 28, 28, 1])
        in_ch = 1

        for layer in range(4):
            with tf.variable_scope('conv%d' % layer):
                kernel = tf.get_variable('W', [3, 3, in_ch, 64])
                beta = tf.get_variable('beta', [64], initializer=tf.constant_initializer(0.0))
                gamma = tf.get_variable('gamma', [64], initializer=tf.constant_initializer(1.0))
                in_ch = 64

                z = tf.nn.conv2d(h, kernel, strides=[1, 1, 1, 1], padding='SAME')
                mu, var = tf.nn.moments(z, [0, 1, 2])
                z = tf.nn.batch_normalization(z, mu, var, beta, gamma, variance_epsilon=1e-10)
                h = tf.nn.max_pool(tf.nn.relu(z), ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')

        # Drop the 1x1 spatial dims, then restore the (batch, episode) split.
        flat = tf.reshape(tf.squeeze(h, [1, 2]), [-1, episode_len, 64])
        output = tf.stop_gradient(flat) if stop_grad else flat

    return output

In [14]:
def calual_conv_with_activation(inp, nb_input, nb_output, dilation_rate, kernel_size=2):
    """Gated, dilated causal 1-D convolution.

    Computes ``tanh(conv_f(x) + b_f) * sigmoid(conv_g(x) + b_g)``. The input is
    left-padded by ``dilation_rate * (kernel_size - 1)`` steps before a VALID
    convolution, so output step t depends only on input steps <= t and the
    time length is preserved. (The name appears to be a typo of "causal"; it is
    kept so existing callers keep working.)

    Creates variables ``W_filter``, ``b_filter``, ``W_gate`` and ``b_gate`` in
    the current variable scope.

    Args:
        inp: 3-D tensor, presumably ``(batch, time, nb_input)`` (the pad spec
            pads axis 1).
        nb_input: number of input channels; must match ``inp``'s last dim.
        nb_output: number of output channels.
        dilation_rate: temporal dilation of both convolutions.
        kernel_size: temporal kernel width. The default of 2 reproduces the
            original layer exactly.

    Returns:
        Tensor ``(batch, time, nb_output)``.
    """
    Wf = tf.get_variable('W_filter', [kernel_size, nb_input, nb_output])
    bf = tf.get_variable('b_filter', [nb_output])
    Wg = tf.get_variable('W_gate', [kernel_size, nb_input, nb_output])
    bg = tf.get_variable('b_gate', [nb_output])

    # Left-pad by (receptive field - 1) so the VALID conv keeps the length and never looks ahead.
    pad = dilation_rate * (kernel_size - 1)
    x = tf.pad(inp, [[0, 0], [pad, 0], [0, 0]])

    xf = tf.nn.convolution(x, Wf, strides=[1], dilation_rate=[dilation_rate], padding='VALID')
    xf = tf.nn.bias_add(xf, bf)

    xg = tf.nn.convolution(x, Wg, strides=[1], dilation_rate=[dilation_rate], padding='VALID')
    xg = tf.nn.bias_add(xg, bg)

    return tf.tanh(xf) * tf.sigmoid(xg)

In [15]:
def res_block(inp, nb_dim, dilation_rate, scope):
    """Gated causal conv layer wrapped in an identity skip connection.

    ``nb_dim`` is used for both the input and output channel counts, so it must
    equal ``inp``'s last dimension for the addition to be valid.
    """
    with tf.variable_scope(scope):
        residual = calual_conv_with_activation(inp, nb_dim, nb_dim, dilation_rate)
        out = residual + inp
    return out

In [16]:
def dense_block(inp, nb_dim, dilation_rate, scope):
    """Project to 128 channels, apply two residual blocks, and append the result to the input.

    ``nb_dim`` is the channel count of ``inp``. The returned tensor has
    ``nb_dim + 128`` channels along axis 2.
    """
    with tf.variable_scope(scope):
        h = calual_conv_with_activation(inp, nb_dim, 128, dilation_rate)
        for res_name in ('res_01', 'res_02'):
            h = res_block(h, 128, dilation_rate, res_name)
        return tf.concat((inp, h), axis=2)

In [17]:
def build_tcml(inp, label, scope, reuse=False, stop_grad=False):
    """Build the TCML head: dilated causal dense blocks over per-step features.

    The per-step features ``inp`` are concatenated with ``label`` along axis 2.
    The result passes through ten dense blocks with dilations
    1, 2, 4, 8, 16, 1, 2, 4, 8, 16, each appending 128 channels. A two-layer
    1x1 conv head (512 hidden units) and a softmax follow.

    Args:
        inp: feature tensor, concatenated with ``label`` along axis 2.
        label: previous-step label tensor. Its static last dimension sets the
            number of output classes.
        scope: variable scope name.
        reuse: if True, reuse the variables already created under ``scope``
            (same convention as ``embd_net``).
        stop_grad: accepted for signature parity with ``embd_net``; currently unused.

    Returns:
        Softmax probabilities with last dimension equal to ``label``'s last dimension.
    """
    # Derive the class count from the label input instead of hard-coding it.
    nb_class = int(label.shape[2])

    with tf.variable_scope(scope) as varscope:
        if reuse:
            varscope.reuse_variables()

        with tf.variable_scope('preprocess'):
            x = tf.concat((inp, label), axis=2)

        # Two stacks of exponentially growing dilations. Scope names are dense_01 .. dense_10.
        dilations = (1, 2, 4, 8, 16, 1, 2, 4, 8, 16)
        for i, rate in enumerate(dilations, start=1):
            x = dense_block(x, int(x.shape[2]), rate, 'dense_%02d' % i)

        with tf.variable_scope('postprocess'):
            W1 = tf.get_variable('W1', [1, int(x.shape[2]), 512])
            b1 = tf.get_variable('b1', [512])
            W2 = tf.get_variable('W2', [1, 512, nb_class])
            b2 = tf.get_variable('b2', [nb_class])

            x = tf.nn.conv1d(x, W1, stride=1, padding='SAME')
            x = tf.nn.bias_add(x, b1)
            x = tf.nn.relu(x)

            x = tf.nn.conv1d(x, W2, stride=1, padding='SAME')
            x = tf.nn.bias_add(x, b2)

        with tf.variable_scope('output'):
            output = tf.nn.softmax(x)

    return output

In [18]:
def build(img, prev_label):
    """Embed the episode images under scope 'embd', then run the TCML head under 'TCML'."""
    return build_tcml(embd_net(img, 'embd'), prev_label, 'TCML')

In [19]:
# Episode images and previous-step one-hot labels. The label width comes from
# the nb_class config rather than a hard-coded 5.
img = tf.placeholder(tf.float32, shape=[None, nb_episode, 28, 28, 1])
prev_label = tf.placeholder(tf.float32, shape=[None, nb_episode, nb_class])

net = build(img, prev_label)

In [20]:
# Initialize every variable created by build() in the interactive session.
sess.run(tf.global_variables_initializer())

In [21]:
# Write the graph definition to ./log for inspection in TensorBoard.
summary_writer = tf.summary.FileWriter(logdir='./log', graph=sess.graph)

In [22]:
# Forward pass on the sampled batch.
feed_dict = {img: batch_x, prev_label: batch_p}
ret = sess.run(net, feed_dict=feed_dict)

In [23]:
np.shape(ret)

(4, 32, 5)