First, we set up the MNIST dataset and load everything we need.

In [1]:
# NOTE(review): this cell assumes `tf` (TensorFlow) and `load_dataset` are
# already in scope; the recorded output of this cell is a NameError for
# `load_dataset`.
BATCH_SIZE = 16
NUM_HIDDEN_UNITS = 16
OUTPUT_NUM_HIDDEN_UNITS = 32
INPUT_RECURRENT_LENGTH = 7   # number of stacked Diagonal BiLSTM layers
OUTPUT_RECURRENT_LENGTH = 2  # number of 1x1 masked output conv layers
# Whether the Diagonal BiLSTM layers use residual connections. Later cells
# read this flag, so it must be defined here.
USE_RESIDUALS = False
LEARNING_RATE = 1e-3
WEIGHTS_INITIALIZER = tf.contrib.layers.xavier_initializer()

# image_height = image_width = 28, num_channels = 1
dataset, image_height, image_width, num_channels, next_train_batch, next_test_batch = load_dataset('mnist')
train_data = dataset.train
test_data = dataset.test



NameError: name 'load_dataset' is not defined

Next, we define a placeholder for the network input and a utility function that returns the static shape of a tensor as a list:

In [2]:
# Batch of images in NHWC layout; the same tensor is both the network input
# and the reconstruction target of the loss. Uses the dataset's channel count
# instead of a hard-coded 1.
inputs = tf.placeholder(tf.float32, [None, image_height, image_width, num_channels])

def get_shape(tensor):
    """Return the static shape of ``tensor`` as a plain Python list."""
    static_shape = tensor.get_shape()
    return static_shape.as_list()

NameError: name 'tf' is not defined

Next, we create a function to perform masked 2-D image convolutions using the masking procedure outlined above. We also create a 1-D (column) convolution, which the Diagonal LSTM cell applies to the previous column of hidden states to compute its state-to-state term.

In [3]:
def conv2d(inputs, num_outputs, kernel_shape, mask_type='A'):
    batch_size, image_height, image_width, num_channels = get_shape(inputs)

    kernel_height, kernel_width = kernel_shape
    # get the location of the pixel being predicted by this kernel
    center_height, center_width = kernel_height // 2, kernel_width // 2
    
    # initialize kernel weights
    weights_shape = [kernel_height, kernel_width, 1, num_outputs]
    weights = tf.get_variable('weights', weights_shape, tf.float32, WEIGHTS_INITIALIZER)
    
    # create and apply the masks to the convolution
    # set all pixels below and all to the right of the center to 0
    mask = np.ones([kernel_height, kernel_width, 1, num_outputs], dtype=np.float32)
    mask[center_height, center_width+1:, :, :] = 0.
    mask[center_height+1:, :, :, :] = 0.
    
    # for type A masks, we do not allow self-connections
    if mask_type == 'A':
        mask[center_height, center_width, :, :] = 0.
        
    # apply the mask
    weights *= tf.constant(mask, dtype=tf.float32)
    
    # apply the convolution
    outputs = tf.nn.conv2d(inputs, weights, [1, 1, 1, 1], padding='SAME')
    return outputs

def conv1d(inputs, num_outputs, kernel_height):
    """Causally convolve each column of ``inputs`` with a (kernel_height x 1) kernel.

    Args:
        inputs: 4-D tensor whose last dimension (channels) is statically
            known; the Diagonal LSTM cell passes [batch, height, 1, hidden_dims].
        num_outputs: number of output channels.
        kernel_height: height of the column kernel.

    Returns:
        A tensor with the same height as ``inputs`` and ``num_outputs``
        channels, where output row r depends only on input rows
        r - kernel_height + 1 .. r.

    NOTE(review): creates a variable named 'weights' in the current variable
    scope.
    """
    num_channels = get_shape(inputs)[-1]

    # Input-channel dimension must match the incoming feature map rather than
    # a hard-coded 1; the initializer name is WEIGHTS_INITIALIZER (plural).
    weights_shape = [kernel_height, 1, num_channels, num_outputs]
    weights = tf.get_variable('weights', weights_shape, tf.float32, WEIGHTS_INITIALIZER)

    # 'SAME' padding with an even kernel height puts the extra padding row at
    # the bottom, so row r would see row r + 1 (a future pixel in the skewed
    # map). Pad only at the top and use 'VALID' so the height is preserved
    # and each row sees itself and the rows above it.
    padded = tf.pad(inputs, [[0, 0], [kernel_height - 1, 0], [0, 0], [0, 0]])
    return tf.nn.conv2d(padded, weights, [1, 1, 1, 1], padding='VALID')


Now we need to construct the Diagonal BiLSTM layers.
To do so, we first construct a Diagonal LSTM cell and a function that runs it over the columns of a skewed image.

In [None]:
class DiagonalLSTMCell(tf.nn.rnn_cell.RNNCell):
    """LSTM cell that processes one column of a skewed feature map per step.

    The state is the concatenation [c, h] of the cell and hidden states, each
    of size height * hidden_dims and laid out row-major as (row, hidden unit).
    Each step input is the precomputed input-to-state term for one column,
    shaped [batch, height * 4 * hidden_dims] and laid out as (row, channel).
    The state-to-state term is a 2x1 column convolution (conv1d) of the
    previous hidden column.
    """

    def __init__(self, hidden_dims, height):
        self._num_unit_shards = 1
        self._forget_bias = 1.0
        self._height = height
        self._hidden_dims = hidden_dims
        self._num_units = hidden_dims * height
        self._state_size = self._num_units * 2
        self._output_size = self._num_units

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def __call__(self, input_to_state, state, scope=None):
        """Advance the cell by one column and return (h, new_state)."""
        with tf.variable_scope(scope or type(self).__name__):
            c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
            hidden_prev = tf.slice(state, [0, self._num_units], [-1, self._num_units])

            # input_to_state shape: [batch, 4 * height * hidden_dims]
            input_size = input_to_state.get_shape().with_rank(2)[1]
            if input_size.value is None:
                raise ValueError('Could not infer input size from input_to_state.')

            conv1d_inputs = tf.reshape(hidden_prev, [-1, self._height, 1, self._hidden_dims])
            conv_state_to_state = conv1d(conv1d_inputs, 4 * self._hidden_dims, 2)
            state_to_state = tf.reshape(conv_state_to_state, [-1, 4 * self._height * self._hidden_dims])

            # Both terms are laid out as (row, 4 * hidden_dims channels). Split
            # the gates along the channel axis of each row; splitting the flat
            # vector would cut it by rows and mix different pixels' gates.
            pre_activations = tf.reshape(state_to_state + input_to_state,
                                         [-1, self._height, 4 * self._hidden_dims])
            i, g, f, o = [tf.reshape(gate, [-1, self._num_units])
                          for gate in tf.split(2, 4, pre_activations)]

            # Input/forget/output gates use a sigmoid; the content gate g uses
            # tanh. The stored forget bias is added before the sigmoid.
            c = tf.sigmoid(f + self._forget_bias) * c_prev + tf.sigmoid(i) * tf.tanh(g)
            h = tf.sigmoid(o) * tf.tanh(c)

            new_state = tf.concat(1, [c, h])
        return h, new_state
    
def diagonal_lstm(inputs, hidden_dims):
    """Run a Diagonal LSTM over ``inputs`` and return its hidden-state map.

    The input is skewed, a 1x1 type-B masked convolution produces the four
    gate pre-activations per pixel, and a DiagonalLSTMCell is unrolled over
    the columns of the skewed map. The per-column outputs are reassembled and
    unskewed.

    Args:
        inputs: 4-D tensor unpacked as [batch, height, width, channels].
        hidden_dims: number of hidden units per pixel.

    Returns:
        The unskewed output map with ``hidden_dims`` channels.

    NOTE(review): ``skew``/``unskew`` are defined elsewhere; presumably they
    shift row i right by i and back — confirm.
    """
    skewed_inputs = skew(inputs)
    # conv2d needs a [height, width] kernel shape; the original int 1 for the
    # mask type also selected type B.
    input_to_state = conv2d(skewed_inputs, 4 * hidden_dims, [1, 1], 'B')
    column_wise_inputs = tf.transpose(input_to_state, [0, 2, 1, 3])

    _, width, height, channel = get_shape(column_wise_inputs)
    rnn_inputs = tf.reshape(column_wise_inputs, [-1, width, height * channel])

    # One RNN step per column of the skewed map.
    split_rnn_inputs = tf.split(split_dim=1, num_split=width, value=rnn_inputs)
    rnn_input_list = [tf.squeeze(step_input, squeeze_dims=[1]) for step_input in split_rnn_inputs]

    cell = DiagonalLSTMCell(hidden_dims, height)

    output_list, _ = tf.nn.rnn(cell, inputs=rnn_input_list, dtype=tf.float32)
    packed_outputs = tf.pack(output_list, axis=1)
    width_first_outputs = tf.reshape(packed_outputs, [-1, width, height, hidden_dims])

    skewed_outputs = tf.transpose(width_first_outputs, [0, 2, 1, 3])
    return unskew(skewed_outputs)

Now we can combine two diagonal LSTMs, one run on the input and one run on the horizontally flipped input, into a Diagonal BiLSTM layer.

In [None]:
def diagonal_bilstm(inputs, hidden_dims, use_residual=False):
    """Combine a left-to-right and a right-to-left Diagonal LSTM over ``inputs``.

    Args:
        inputs: 4-D tensor unpacked as [batch, height, width, channels]. When
            ``use_residual`` is true it must have ``2 * hidden_dims`` channels
            so it can be added to the residual convolutions.
        hidden_dims: number of hidden units per pixel of each direction.
        use_residual: add a 1x1 type-B convolution of each direction's output
            to ``inputs``.

    Returns:
        The sum of the forward map and the backward map shifted down by one
        row.
    """
    def reverse(tensor):
        # Flip the width axis so the diagonal LSTM scans right-to-left.
        return tf.reverse(tensor, [False, False, True, False])

    output_state_forward = diagonal_lstm(inputs, hidden_dims)
    output_state_backward = reverse(diagonal_lstm(reverse(inputs), hidden_dims))

    if use_residual:
        residual_state_forward = conv2d(output_state_forward, hidden_dims * 2, [1, 1], 'B')
        output_state_forward = residual_state_forward + inputs

        residual_state_backward = conv2d(output_state_backward, hidden_dims * 2, [1, 1], 'B')
        output_state_backward = residual_state_backward + inputs

    _, height, _, _ = get_shape(output_state_backward)
    backward_except_last = tf.slice(output_state_backward, [0, 0, 0, 0], [-1, height - 1, -1, -1])
    backward_only_last = tf.slice(output_state_backward, [0, height - 1, 0, 0], [-1, 1, -1, -1])
    zero_row = tf.zeros_like(backward_only_last)

    # Shift the backward map down one row (zero row on top, last row dropped).
    # At row r the right-to-left pass has already seen pixels to the right in
    # row r, so it may only contribute to row r + 1. Appending the zeros at the
    # bottom instead would keep rows aligned and leak future pixels.
    shifted_backward = tf.concat(1, [zero_row, backward_except_last])

    return output_state_forward + shifted_backward
    

Now we can move on to actually constructing the network. First, we apply a masked (type A) 7x7 convolution to the image.

In [4]:
kernel_shape = [7, 7]
# With residual connections the first layer must already be 2 * hidden units
# wide, matching the residual convolutions inside diagonal_bilstm.
conv_2d_inputs = conv2d(
    inputs,
    2 * NUM_HIDDEN_UNITS if USE_RESIDUALS else NUM_HIDDEN_UNITS,
    kernel_shape,
    'A',
)

NameError: name 'inputs' is not defined

Next, we construct all the Diagonal BiLSTM layers, followed by the 1x1 masked output convolutions with ReLU activations.

In [None]:
# Stack of Diagonal BiLSTM layers.
last_input = conv_2d_inputs
for _ in range(INPUT_RECURRENT_LENGTH):
    last_input = diagonal_bilstm(last_input, NUM_HIDDEN_UNITS, USE_RESIDUALS)

# Output stack: 1x1 masked (type B) convolutions, each followed by a ReLU.
kernel_shape = [1, 1]
for _ in range(OUTPUT_RECURRENT_LENGTH):
    conv_layer = conv2d(last_input, OUTPUT_NUM_HIDDEN_UNITS, kernel_shape, 'B')
    recurrent_out = tf.nn.relu(conv_layer)
    last_input = recurrent_out

recurrent_out_logits = last_input

We can now apply a final 1x1 convolution to the output of the recurrent portion of the network to get per-pixel logits, and a sigmoid to turn them into our predictions for each image.

In [None]:
# Final 1x1 masked (type B) convolution down to a single channel of per-pixel
# logits. conv2d takes (inputs, num_outputs, kernel_shape, mask_type); the
# previous five-argument call raised a TypeError.
conv2d_recurrent_logits = conv2d(recurrent_out_logits, 1, [1, 1], 'B')
output = tf.nn.sigmoid(conv2d_recurrent_logits)

Now we define the loss function to optimize and create the update step for the network, along with helpers for prediction, evaluation, and sampling.

In [None]:
optimizer = tf.train.RMSPropOptimizer(LEARNING_RATE)

# Per-pixel sigmoid cross-entropy between the final logits and the input
# image itself: the network is trained to reproduce its input.
all_losses = tf.nn.sigmoid_cross_entropy_with_logits(conv2d_recurrent_logits, inputs)
loss = tf.reduce_mean(all_losses)

# Separate compute/apply steps (equivalent to optimizer.minimize(loss)).
gradients_and_vars = optimizer.compute_gradients(loss)
update_step = optimizer.apply_gradients(gradients_and_vars)

def predict(sess, images):
    """Run the network on ``images`` and return the per-pixel sigmoid outputs."""
    feed = {inputs: images}
    predictions = sess.run(output, feed_dict=feed)
    return predictions

def test(sess, images, with_update=False):
    """Return the mean loss on ``images``, optionally running one training step.

    When ``with_update`` is true the update step runs in the same session
    call, and the loss value from that call is returned.
    """
    feed = {inputs: images}
    if not with_update:
        return sess.run(loss, feed_dict=feed)
    _, cost = sess.run([update_step, loss], feed_dict=feed)
    return cost

def generate_images(sess, batch_shape, starting_pos=None, starting_image=None, staring_image=None):
    """Sample a batch of images from the model one pixel at a time.

    Args:
        sess: session in which the model graph is evaluated.
        batch_shape: [batch_size, height, width, channels] of the batch.
        starting_pos: [column, row] of the first pixel to sample (defaults
            to [0, 0]). Pixels before it in raster order are left unchanged.
        starting_image: optional array to start from; it is copied, not
            modified. Defaults to zeros of ``batch_shape``.
        staring_image: misspelled alias of ``starting_image``, kept for
            backward compatibility.

    Returns:
        The array of samples, with shape ``batch_shape`` when no starting
        image is given.
    """
    if starting_pos is None:
        starting_pos = [0, 0]
    if starting_image is None:
        starting_image = staring_image

    if starting_image is not None:
        samples = starting_image.copy()
    else:
        samples = np.zeros(batch_shape, dtype='float32')

    _, height, width, channels = batch_shape
    start_col, start_row = starting_pos
    for i in range(start_row, height):
        # Only the first sampled row may start mid-row; every later row must
        # be sampled from column 0.
        first_col = start_col if i == start_row else 0
        for j in range(first_col, width):
            for k in range(channels):
                # Re-run the whole network and keep only the pixel being sampled.
                next_sample = binarize(predict(sess, samples))
                samples[:, i, j, k] = next_sample[:, i, j, k]
    return samples
        