In [None]:
class AttentionCell(RNNCell):
    """Early-fusion attention wrapper around an RNN cell.

    At every step, ``(inputs, state)`` is used by a *controller* to score each
    memory slot. The memory is softly selected with those scores, and a
    *mapper* fuses the selected vector into the inputs (and optionally the
    state) before the wrapped cell is run.
    """

    def __init__(self, cell, memory, mask=None, controller=None, mapper=None, input_keep_prob=1.0, is_train=None):
        """
        Early fusion attention cell: uses the (inputs, state) to control the current attention.

        :param cell: wrapped RNN cell; its state_size / output_size are exposed unchanged.
        :param memory: [N, M, m] tensor to attend over.
        :param mask: optional mask over the memory slots; None means no masking.
        :param controller: (inputs, prev_state, memory) -> memory_logits. Defaults to a
            biased linear controller built with ``input_keep_prob`` and ``is_train``.
        :param mapper: (inputs, state, sel_mem) -> (new_inputs, new_state), or the string
            'sim'. Defaults to concatenating inputs with the selected memory.
        :param input_keep_prob: keep probability forwarded to the default controller.
        :param is_train: training flag forwarded to the default controller.
        :raises ValueError: if ``mapper`` is a string other than 'sim'.
        """
        self._cell = cell
        self._memory = memory
        self._mask = mask
        # NOTE(review): flatten(x, k) presumably collapses all but the last k dims — confirm.
        self._flat_memory = flatten(memory, 2)
        # A None mask cannot be flattened; assumes softsel treats mask=None as
        # "no masking" — confirm against softsel's definition.
        self._flat_mask = None if mask is None else flatten(mask, 1)
        if controller is None:
            # Forward input_keep_prob too; previously it was accepted but silently dropped.
            controller = AttentionCell.get_linear_controller(
                True, input_keep_prob=input_keep_prob, is_train=is_train)
        self._controller = controller
        if mapper is None:
            mapper = AttentionCell.get_concat_mapper()
        elif isinstance(mapper, str):
            # Fail early rather than storing a string that is later "called".
            if mapper != 'sim':
                raise ValueError("Unknown mapper name: {!r} (expected 'sim')".format(mapper))
            mapper = AttentionCell.get_sim_mapper()
        self._mapper = mapper

    @property
    def state_size(self):
        """State size of the wrapped cell."""
        return self._cell.state_size

    @property
    def output_size(self):
        """Output size of the wrapped cell."""
        return self._cell.output_size

    def __call__(self, inputs, state, scope=None):
        """Run one attended step of the wrapped cell.

        :param inputs: [N, i]
        :param state: the wrapped cell's state (a tensor or a tuple of tensors).
        :param scope: optional variable scope name; defaults to "AttentionCell".
        :return: whatever the wrapped cell returns when called with the
            mapper's (new_inputs, new_state).
        """
        with tf.variable_scope(scope or "AttentionCell"):
            memory_logits = self._controller(inputs, state, self._flat_memory)
            sel_mem = softsel(self._flat_memory, memory_logits, mask=self._flat_mask)  # [N, m]
            new_inputs, new_state = self._mapper(inputs, state, sel_mem)
            # Use the mapper's state so state-rewriting mappers actually take effect.
            return self._cell(new_inputs, new_state)

    @staticmethod
    def _controller_input(inputs, state, memory):
        """Tile inputs and state(s) along the memory axis and concatenate with memory.

        :param inputs: [N, i]
        :param state: [N, d] tensor, or a tuple of such tensors.
        :param memory: [N, M, m]; the slot axis is taken to be the second-to-last one.
        :return: [N, M, i + sum(d) + m]
        """
        rank = len(memory.get_shape())
        memory_size = tf.shape(memory)[rank - 2]

        def tile(tensor):
            # [N, x] -> [N, M, x] so each memory slot sees the same inputs/state.
            return tf.tile(tf.expand_dims(tensor, 1), [1, memory_size, 1])

        states = state if isinstance(state, tuple) else (state,)
        return tf.concat([tile(inputs)] + [tile(each) for each in states] + [memory], axis=2)

    @staticmethod
    def get_double_linear_controller(size, bias, input_keep_prob=1.0, is_train=None):
        """Build a controller that scores memory slots with ``double_linear_logits``.

        :param size: hidden size passed to ``double_linear_logits``.
        :param bias: whether ``double_linear_logits`` uses a bias.
        :param input_keep_prob: keep probability passed to ``double_linear_logits``.
        :param is_train: training flag passed to ``double_linear_logits``.
        :return: callable (inputs, state, memory) -> logits.
        """
        def double_linear_controller(inputs, state, memory):
            """

            :param inputs: [N, i]
            :param state: [N, d]
            :param memory: [N, M, m]
            :return: [N, M]
            """
            in_ = AttentionCell._controller_input(inputs, state, memory)
            return double_linear_logits(in_, size, bias, input_keep_prob=input_keep_prob,
                                        is_train=is_train)
        return double_linear_controller

    @staticmethod
    def get_linear_controller(bias, input_keep_prob=1.0, is_train=None):
        """Build a controller that scores memory slots with a single ``linear`` layer.

        :param bias: whether ``linear`` uses a bias.
        :param input_keep_prob: keep probability passed to ``linear``.
        :param is_train: training flag passed to ``linear``.
        :return: callable (inputs, state, memory) -> logits (last dim squeezed).
        """
        def linear_controller(inputs, state, memory):
            """

            :param inputs: [N, i]
            :param state: [N, d]
            :param memory: [N, M, m]
            :return: [N, M]
            """
            in_ = AttentionCell._controller_input(inputs, state, memory)
            return linear(in_, 1, bias, squeeze=True, input_keep_prob=input_keep_prob, is_train=is_train)
        return linear_controller

    @staticmethod
    def get_concat_mapper():
        """Build a mapper that appends the selected memory to the inputs; state is unchanged."""
        def concat_mapper(inputs, state, sel_mem):
            """

            :param inputs: [N, i]
            :param state: [N, d]
            :param sel_mem: [N, m]
            :return: (new_inputs, new_state) tuple
            """
            return tf.concat(axis=1, values=[inputs, sel_mem]), state
        return concat_mapper

    @staticmethod
    def get_sim_mapper():
        """Build a mapper that appends [sel_mem, inputs*sel_mem, |inputs-sel_mem|] to the inputs.

        The elementwise product and difference require inputs and sel_mem to
        have the same last dimension. The state is passed through unchanged.
        """
        def sim_mapper(inputs, state, sel_mem):
            """
            Assume that inputs and sel_mem are the same size
            :param inputs: [N, i]
            :param state: [N, d]
            :param sel_mem: [N, i]
            :return: (new_inputs, new_state) tuple
            """
            return tf.concat(axis=1, values=[inputs, sel_mem, inputs * sel_mem, tf.abs(inputs - sel_mem)]), state
        return sim_mapper