In [213]:
import tensorflow as tf
import numpy as np
from quebap.projects.modelF.structs import FrozenIdentifier
%load_ext autoreload
%autoreload 2

In [51]:
# Start from an empty default graph so that re-running this cell does not
# add its ops on top of the graph left behind by a previous run.
tf.reset_default_graph()

def tok(string):
    """Split a sentence into its whitespace-separated word tokens.

    Splitting on runs of whitespace (instead of on every single space) means
    that doubled, leading or trailing blanks never produce empty-string
    tokens, which would otherwise become a bogus vocabulary entry.

    Args:
        string: the sentence to tokenise.

    Returns:
        A list of non-empty tokens; an empty or blank sentence gives [].
    """
    return string.split()

# Toy knowledge base: every statement is stored as a list of word tokens.
statements = [
    tok(sentence)
    for sentence in (
        "eagle can fly",
        "eagle is a bird",
        "bird can fly",
        "duck can fly",
    )
]

# Questions contain the token '_' (presumably the slot to be filled in; its
# dimension is given zero cost in `cost_for_remainder` below).
questions = [tok(sentence) for sentence in ("duck can _",)]

# Symbol used for padding and for positions that have been matched away.
WHITESPACE = '[WS]'

# Vocabulary: every token seen in the statements and questions, plus padding.
words = set().union(*(statements + questions))
words.add(WHITESPACE)
vocab = FrozenIdentifier(words)

# One-hot embeddings taken from an identity matrix.
# NOTE(review): assumes vocab[token] yields an integer id in [0, len(vocab)) - confirm.
embeddings = tf.diag(tf.ones(len(vocab)))
whitespace_repr = tf.gather(embeddings, vocab[WHITESPACE])  # [repr_dim]

# Per-dimension cost vector: 1 minus the one-hot rows of '_' and whitespace,
# so those two dimensions carry no cost.
cost_for_remainder = 1 - tf.gather(embeddings, vocab['_']) - whitespace_repr

def match_and_extract(statement, question):
    """Softly align a statement with a question, position by position.

    For every token position the two representations are compared; where they
    match, both are replaced by the whitespace representation, elsewhere they
    are kept. Relies on the module-level tensors ``whitespace_repr`` and
    ``cost_for_remainder``.

    Args:
        statement: [batch_size, max_length, repr_dim] embedded statement tokens.
        question: [batch_size, max_length, repr_dim] embedded question tokens.

    Returns:
        A 3-tuple ``(extraction, costs, question_leftover)``:
        extraction: [batch_size, max_length, repr_dim]; the statement with
            matched positions blended towards whitespace.
        costs: [batch_size]; ``question_leftover`` weighted by
            ``cost_for_remainder`` and summed over positions and dimensions.
        question_leftover: [batch_size, max_length, repr_dim]; the question
            with matched positions blended towards whitespace.
    """
    # Dot product per position, clamped into [0, 1] so it can act as a blend
    # weight. Alternatives: sigmoids, a distance-based score, etc.
    similarity = tf.reduce_sum(statement * question, 2)  # [batch_size, max_length]
    capped = tf.minimum(similarity, 1.0)
    match = tf.expand_dims(tf.maximum(capped, 0.0), 2)  # [batch_size, max_length, 1]
    mismatch = 1.0 - match

    # Matched positions collapse to whitespace in both outputs.
    blanked = whitespace_repr * match  # [batch_size, max_length, repr_dim]

    extraction = blanked + mismatch * statement
    question_leftover = blanked + mismatch * question

    costs = tf.reduce_sum(question_leftover * cost_for_remainder, [1, 2])
    return extraction, costs, question_leftover

def repr_text_batch(texts):
    """Turn a batch of token lists into equally long rows of vocabulary lookups.

    Each token is looked up in the module-level ``vocab``; rows shorter than
    the longest text are right-padded with ``vocab[WHITESPACE]``.

    Args:
        texts: non-empty list of token lists.

    Returns:
        One list per text, each as long as the longest text in the batch.
    """
    width = max(len(text) for text in texts)
    batch = []
    for text in texts:
        row = [vocab[token] for token in text]
        # Pad on the right until the row reaches the batch width.
        while len(row) < width:
            row.append(vocab[WHITESPACE])
        batch.append(row)
    return batch

def embed_statements(statements):
    """Look up the row of the module-level ``embeddings`` for every token id.

    Args:
        statements: integer token ids (tensor, placeholder or nested list).

    Returns:
        The gathered embedding rows, one per id in ``statements``.
    """
    embedded = tf.gather(embeddings, statements)
    return embedded
        
# Integer token-id inputs; both dimensions are left unspecified.
statement_placeholder = tf.placeholder(dtype=tf.int32, shape=(None, None))
question_placeholder = tf.placeholder(dtype=tf.int32, shape=(None, None))

# Graph output: the (extraction, costs, question_leftover) triple.
extract_result = match_and_extract(
    embed_statements(statement_placeholder),
    embed_statements(question_placeholder),
)

# The session stays open: later cells in this notebook reuse `sess`.
sess = tf.Session()
# Evaluate the whitespace one-hot once; the returned value is discarded
# because this is not the last expression of the cell.
sess.run(whitespace_repr)

def to_feed_dict(statements, questions):
    """Build the feed dict for the statement and question placeholders.

    Statements and questions are passed to ``repr_text_batch`` as one batch so
    that both groups are padded to the same length, then split apart again.

    Args:
        statements: list of tokenised statements.
        questions: list of tokenised questions.

    Returns:
        Dict mapping ``statement_placeholder`` and ``question_placeholder`` to
        their padded id rows.
    """
    split_at = len(statements)
    padded = repr_text_batch(statements + questions)
    return {
        statement_placeholder: padded[:split_at],
        question_placeholder: padded[split_at:],
    }

# Run the extraction for the statement "duck can fly" against the question
# "duck can _". The fetched values are discarded because this is not the last
# expression of the cell.
sess.run(
    (extract_result, whitespace_repr, cost_for_remainder),
    feed_dict=to_feed_dict(statements[3:4], questions[:1]),
)

# Column 1 of the identity embedding matrix; this is the value the cell displays.
sess.run(embeddings[:, 1])

array([ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.], dtype=float32)

In [212]:
## import numpy as np

class CompactifyCell(tf.nn.rnn_cell.RNNCell):
    """RNN cell that packs the masked-in tokens of a sequence into consecutive slots.

    Every time step receives ``[mask, token...]``. With ``mask == 1`` the token
    is added into the slot selected by a one-hot counter and the counter moves
    one slot forward; with ``mask == 0`` result and counter stay unchanged.
    The counter wraps around after ``max_compact_length`` slots, so further
    tokens are added on top of whatever the slot already holds.

    The state is the pair ``(result, counter)``:
        result:  [batch_size, max_compact_length * input_dim], flattened slots.
        counter: [batch_size, max_compact_length], one-hot write position.
    """

    def __init__(self, max_compact_length, input_dim, zero_result=None):
        """Create the cell.

        Args:
            max_compact_length: number of slots in the compacted result.
            input_dim: size of a token vector (without the leading mask entry).
            zero_result: optional initial slot content that can be reshaped to
                [batch_size, max_compact_length * input_dim]; zeros if None.
        """
        self._max_compact_length = max_compact_length
        self._input_dim = input_dim
        # Row i holds a single 1 in column (i + 1) % max_compact_length, so
        # multiplying a one-hot counter by this matrix advances it by one slot.
        self._shift_matrix = np.roll(np.eye(max_compact_length), 1, axis=1)
        self._zero_result = zero_result

    def __call__(self, inputs, state, scope=None):
        """Process one time step.

        Args:
            inputs: [batch_size, input_dim + 1]; column 0 is the mask, the
                remaining columns are the token.
            state: ``(result, counter)`` as described in the class docstring.
            scope: unused.

        Returns:
            ``(new_result, (new_result, new_counter))``; the output equals the
            flattened result part of the new state.
        """
        result, counter = state
        mask = inputs[:, 0:1]         # [batch_size, 1]
        input_tokens = inputs[:, 1:]  # [batch_size, input_dim]

        # [batch_size, max_compact_length, input_dim]
        result_matrix = tf.reshape(
            result, (-1, self._max_compact_length, self._input_dim))

        # Outer product of counter and token: the token lands in the slot the
        # counter points at. [batch_size, max_compact_length, input_dim]
        input_at_counter = tf.expand_dims(input_tokens, 1) * tf.expand_dims(counter, 2)
        expanded_mask = tf.expand_dims(mask, 1)  # [batch_size, 1, 1]

        new_result_matrix = (expanded_mask * (result_matrix + input_at_counter)
                             + (1.0 - expanded_mask) * result_matrix)

        # Build the shift in the counter's own dtype so that running the cell
        # with a dtype other than float32 does not break the matmul.
        shift = tf.constant(self._shift_matrix, dtype=counter.dtype)
        new_counter = mask * tf.matmul(counter, shift) + (1.0 - mask) * counter

        new_result = tf.reshape(new_result_matrix, tf.shape(result))
        return new_result, (new_result, new_counter)

    def zero_state(self, batch_size, dtype):
        """Return the initial ``(result, counter)`` state.

        Args:
            batch_size: number of sequences in the batch (int or scalar tensor).
            dtype: dtype of the created tensors; float32 when None.

        Returns:
            ``(zero_result, zero_counter)`` where the result is either zeros or
            the reshaped ``zero_result`` given to the constructor, and the
            counter is one-hot at slot 0.
        """
        if dtype is None:
            dtype = tf.float32
        flat_size = self._input_dim * self._max_compact_length
        if self._zero_result is None:
            zero_result = tf.zeros((batch_size, flat_size), dtype=dtype)
        else:
            zero_result = tf.reshape(self._zero_result, (batch_size, flat_size))
        # NOTE(review): (axis, values) is the pre-1.0 TensorFlow argument order
        # used by this notebook; newer releases expect tf.concat(values, axis).
        zero_counter = tf.concat(1, [
            tf.ones((batch_size, 1), dtype=dtype),
            tf.zeros((batch_size, self._max_compact_length - 1), dtype=dtype),
        ])
        return (zero_result, zero_counter)

    @property
    def state_size(self):
        """Sizes of the state parts: (flattened result, counter)."""
        return (self._input_dim * self._max_compact_length, self._max_compact_length)

    @property
    def output_size(self):
        """Size of the output, i.e. of the flattened result."""
        return self._input_dim * self._max_compact_length
    
# Two slots, three-dimensional tokens.
cell = CompactifyCell(max_compact_length=2, input_dim=3)

# [batch, time, 1 + input_dim]; the leading entry of every step is the mask.
inputs = tf.constant([
    [[1., 0.5, 0, -1], [1., 0, 2, -1]],  # both steps kept
    [[0., 3, 0, -1], [1., 0, 4, -1]],    # first step masked out
])

outputs, (result, counter) = tf.nn.dynamic_rnn(
    cell=cell,
    inputs=inputs,
    sequence_length=(2, 2),
    dtype=tf.float32,
)

# Final compacted slots and write position for each batch entry.
sess.run((result, counter))



(array([[ 0.5,  0. , -1. ,  0. ,  2. , -1. ],
        [ 0. ,  4. , -1. ,  0. ,  0. ,  0. ]], dtype=float32),
 array([[ 1.,  0.],
        [ 0.,  1.]], dtype=float32))