In [5]:
import tensorflow as tf
import numpy as np
from quebap.projects.modelF.structs import FrozenIdentifier


In [46]:
# Clear the global default graph before building anything below
# (presumably so that re-running this cell does not pile up duplicate ops).
tf.reset_default_graph()

def tok(string):
    """Tokenise *string* by splitting it on single space characters.

    The separator is passed explicitly, so runs of consecutive spaces
    produce empty-string tokens and an empty input yields ``[""]``.
    """
    separator = " "
    tokens = string.split(separator)
    return tokens

# Toy fact base: every entry is one sentence, tokenised with tok().
statements = list(map(tok, [
    "eagle can fly",
    "eagle is a bird",
    "bird can fly",
    "duck can fly",
]))

# Questions use the same tokenisation; '_' is the token that is exempted
# from the leftover cost below (it appears to mark the answer slot).
questions = list(map(tok, [
    "duck can _",
]))

# Symbol used both for padding and for "matched" positions.
WHITESPACE = '[WS]'

# Vocabulary: every token seen in any statement or question, plus WHITESPACE.
words = set().union(*(statements + questions))
words.add(WHITESPACE)
vocab = FrozenIdentifier(words)

# Identity matrix with len(vocab) rows; vocab[token] is used as the row index,
# so each token is represented by a one-hot vector.
embeddings = tf.diag(tf.ones(len(vocab)))
whitespace_repr = tf.gather(embeddings, vocab[WHITESPACE])  # [repr_dim]
# All-ones vector minus the rows for '_' and WHITESPACE: those two symbols
# contribute nothing when a leftover question is multiplied by this vector.
cost_for_remainder = 1 - tf.gather(embeddings, vocab['_']) - whitespace_repr

def match_and_extract(statement, question):
    """Softly align a statement with a question, token by token.

    Args:
        statement: embedded statement tokens, [batch_size, max_length, repr_dim].
        question: embedded question tokens, [batch_size, max_length, repr_dim].

    Returns:
        A 3-tuple ``(extraction, costs, question_leftover)``:
          * extraction: [batch_size, max_length, repr_dim] -- the statement with
            every matching position blended towards ``whitespace_repr``.
          * costs: [batch_size] -- sum over tokens and dimensions of the
            leftover question weighted by ``cost_for_remainder``.
          * question_leftover: [batch_size, max_length, repr_dim] -- the question
            with every matching position blended towards ``whitespace_repr``.
    """
    # Dot product per token pair, clamped into [0, 1] to act as a match weight.
    # (Alternatives noted by the original author: sigmoids, distances, ...)
    overlap = tf.reduce_sum(statement * question, 2)  # [batch_size, max_length]
    match_scores = tf.maximum(tf.minimum(overlap, 1.0), 0.0)

    # Add a trailing axis so the weights broadcast across repr_dim.
    match_weight = tf.expand_dims(match_scores, 2)
    mismatch_weight = 1.0 - match_weight

    # Matched positions become WHITESPACE in both outputs ...
    whitespace_part = whitespace_repr * match_weight  # [batch_size, max_length, repr_dim]
    # ... while unmatched positions keep their own representation.
    extraction = whitespace_part + mismatch_weight * statement
    question_leftover = whitespace_part + mismatch_weight * question

    costs = tf.reduce_sum(question_leftover * cost_for_remainder, [1, 2])
    return extraction, costs, question_leftover

def repr_text_batch(texts):
    """Convert tokenised texts into equal-length rows of vocabulary entries.

    Args:
        texts: a list of token lists.

    Returns:
        One row per text, in input order. Each row holds ``vocab[token]`` for
        every token of the text, right-padded with ``vocab[WHITESPACE]`` up to
        the length of the longest text in the batch. An empty batch yields
        an empty list.
    """
    if not texts:
        # max() over an empty sequence raises ValueError; an empty batch
        # simply has no rows.
        return []
    max_length = max(len(text) for text in texts)
    padding = vocab[WHITESPACE]
    return [
        [vocab[token] for token in text] + [padding] * (max_length - len(text))
        for text in texts
    ]

def embed_statements(statements):
    """Look up the rows of ``embeddings`` selected by the ids in *statements*.

    The result has the shape of *statements* with one extra trailing axis
    holding the embedding vector of each id.
    """
    embedded = tf.gather(embeddings, statements)
    return embedded
        
# Integer id inputs for statements and questions; both dimensions are left
# unspecified so any batch size / sequence length can be fed.
statement_placeholder = tf.placeholder(tf.int32, shape=(None, None))
question_placeholder = tf.placeholder(tf.int32, shape=(None, None))

# Wire the graph: embed both id tensors, then run the matcher on the pair.
extract_result = match_and_extract(
    embed_statements(statement_placeholder),
    embed_statements(question_placeholder),
)

sess = tf.Session()
# Evaluates the whitespace vector; the returned value is not kept.
sess.run(whitespace_repr)

def to_feed_dict(statements, questions):
    """Build the feed dict for ``statement_placeholder`` / ``question_placeholder``.

    Statements and questions are encoded together in a single call to
    ``repr_text_batch`` so that both groups are padded to one common length,
    then split back apart by position.
    """
    n_statements = len(statements)
    encoded = repr_text_batch(statements + questions)
    return {
        statement_placeholder: encoded[:n_statements],
        question_placeholder: encoded[n_statements:],
    }

# Evaluate the matcher on the fourth statement ("duck can fly") against the
# first question ("duck can _"), and fetch the whitespace and cost vectors
# alongside it for inspection. Fetch order: (extraction, costs,
# question_leftover), whitespace_repr, cost_for_remainder.
sess.run((extract_result,whitespace_repr,cost_for_remainder), 
         feed_dict=to_feed_dict(statements[3:4], questions[:1]))



((array([[[ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.]]], dtype=float32),
  array([ 0.], dtype=float32),
  array([[[ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.]]], dtype=float32)),
 array([ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.], dtype=float32),
 array([ 1.,  0.,  1.,  1.,  1.,  0.,  1.,  1.,  1.], dtype=float32))