In [1]:
import tensorflow as tf
import numpy as np
from quebap.projects.modelF.structs import FrozenIdentifier
%load_ext autoreload
%autoreload 2

In [193]:
# Reset the default graph so the ops created below start from an empty graph
# instead of being added to whatever an earlier run of this cell built.
tf.reset_default_graph()

def tok(string):
    """Tokenise ``string`` by splitting on single space characters.

    The split uses the literal " " separator, so consecutive spaces produce
    empty-string tokens and an empty input produces ``[""]``.
    """
    separator = " "
    tokens = string.split(separator)
    return tokens

# Toy knowledge base: each entry is one tokenised statement.
statements = [tok(w) for w in [
    "eagle can fly",
    "eagle => bird",
    "bird can fly",
    "duck can fly"
]]

# Tokenised questions. NOTE(review): "_" presumably marks the slot being asked
# about; in this cell it is only visible as a token exempted from the cost below.
questions = [tok(w) for w in [
    "duck can _"
]]

# Padding symbol; repr_text_batch uses it to fill shorter sequences.
WHITESPACE = '[WS]'
# Vocabulary = every distinct token of the statements and questions, plus the padding symbol.
words = { word for statement in statements + questions for word in statement }
words.add(WHITESPACE)
vocab = FrozenIdentifier(words)  # vocab[token] is used below as a row index into `embeddings`
# Identity matrix: row i is the one-hot representation of the token with id i.
embeddings = tf.diag(tf.ones(len(vocab)))
whitespace_repr = tf.gather(embeddings, vocab[WHITESPACE]) # [repr_dim]
# Per-dimension weights: 1 everywhere except 0 on the '_' and WHITESPACE dimensions,
# so only mass on other tokens contributes when this is used to weight a representation.
cost_for_remainder = 1 - tf.gather(embeddings, vocab['_']) - whitespace_repr

def match_and_extract(statement, question):
    """Blank out the positions where a statement and a question hold the same token.

    For every batch row and every position the two token representations are
    compared; where they agree both are replaced by ``whitespace_repr``,
    elsewhere each keeps its own token.

    Args:
        statement: [batch_size, max_length, repr_dim]
        question: [batch_size, max_length, repr_dim]

    Returns:
        A triple ``(extraction, costs, question_leftover)``:
          extraction: [batch_size, max_length, repr_dim], the statement with
            matched positions blended towards whitespace.
          costs: [batch_size], ``question_leftover`` weighted by
            ``cost_for_remainder`` and summed over positions and dimensions.
          question_leftover: [batch_size, max_length, repr_dim], the question
            with the same matched positions blended towards whitespace.
    """
    # Position-wise dot product, clamped into [0, 1] so it acts as a soft
    # "tokens agree" gate (alternatively: sigmoids, look for distance etc.).
    overlap = tf.reduce_sum(statement * question, 2)  # [batch_size, max_length]
    gate = tf.expand_dims(tf.maximum(tf.minimum(overlap, 1.0), 0.0), 2)  # [batch_size, max_length, 1]
    keep = 1.0 - gate

    blanked = whitespace_repr * gate  # [batch_size, max_length, repr_dim]
    extraction = blanked + keep * statement
    question_leftover = blanked + keep * question  # [batch_size, max_length, repr_dim]
    costs = tf.reduce_sum(question_leftover * cost_for_remainder, [1, 2])
    return extraction, costs, question_leftover

def repr_text_batch(texts):
    """Convert tokenised texts into equally long rows of vocabulary ids.

    Args:
        texts: list of token lists.

    Returns:
        One list of ids per text, each padded on the right with the id of
        WHITESPACE up to the length of the longest text. An empty ``texts``
        yields an empty list (``max`` would otherwise raise on it).
    """
    if not texts:
        return []
    max_length = max(len(text) for text in texts)
    pad_id = vocab[WHITESPACE]
    return [
        [vocab[token] for token in text] + [pad_id] * (max_length - len(text))
        for text in texts
    ]

def embed_statements(statements):
    """Look up the row of ``embeddings`` for every token id in ``statements``.

    The result has the shape of ``statements`` plus one trailing axis holding
    the selected embedding row.
    """
    token_vectors = tf.gather(embeddings, statements)
    return token_vectors
        
# Integer token-id inputs with two dynamic axes; to_feed_dict feeds them
# one row of ids per text.
statement_placeholder = tf.placeholder(tf.int32,(None,None))
question_placeholder = tf.placeholder(tf.int32, (None,None))

# Build the graph once: the (extraction, costs, question_leftover) triple for the embedded inputs.
extract_result = match_and_extract(embed_statements(statement_placeholder), embed_statements(question_placeholder))

sess = tf.Session()
# Evaluates the whitespace one-hot vector; the value is discarded because this
# is not the last expression of the cell.
sess.run(whitespace_repr)

def to_feed_dict(statements, questions):
    """Build the feed dict for the statement and question placeholders.

    Statements and questions are encoded together in one call to
    ``repr_text_batch`` so both groups are padded to the same length, then
    split back into their original groups.
    """
    num_statements = len(statements)
    encoded = repr_text_batch(statements + questions)
    return {
        statement_placeholder: encoded[:num_statements],
        question_placeholder: encoded[num_statements:],
    }

# Match the single statement "duck can fly" (statements[3:4]) against the
# question "duck can _", and fetch the whitespace vector and the cost weights
# alongside it for inspection.
sess.run((extract_result,whitespace_repr,cost_for_remainder), 
         feed_dict=to_feed_dict(statements[3:4], questions[:1]))

# Column 1 of the identity embedding matrix.
sess.run(embeddings[:,1])

'a asd'

In [195]:
# simple extract_or_translate unit
# NOTE(review): the `model` alias is not referenced in any visible cell --
# confirm this import is still needed.
import quebap.projects.playqa.model as model
# Reset the default graph so the ops created below start from an empty graph.
tf.reset_default_graph()

def tok(string):
    """Tokenise ``string`` by splitting on single space characters.

    The split uses the literal " " separator, so consecutive spaces produce
    empty-string tokens and an empty input produces ``[""]``.
    """
    separator = " "
    tokens = string.split(separator)
    return tokens

# Toy knowledge base; "isa" is the relation token that translate_token (below) selects.
statements = [tok(w) for w in [
    "eagle can fly",
    "eagle isa bird",
    "bird can fly",
    "duck can fly"
]]

# "_" is the token that wh_token (below) selects.
questions = [tok(w) for w in [
    "eagle can _"
]]

# Padding symbol; repr_text_batch uses it to fill shorter sequences.
WHITESPACE = '[WS]'
# Vocabulary = every distinct token of the statements and questions, plus the padding symbol.
words = { word for statement in statements + questions for word in statement }
words.add(WHITESPACE)
vocab = FrozenIdentifier(words)  # vocab[token] is used below as a row index into `embeddings`
# Identity matrix: row i is the one-hot representation of the token with id i.
embeddings = tf.diag(tf.ones(len(vocab)))
whitespace_repr = tf.gather(embeddings, vocab[WHITESPACE]) # [repr_dim]
# Per-dimension weights: 1 everywhere except 0 on the '_' and WHITESPACE dimensions.
cost_for_remainder = 1 - tf.gather(embeddings, vocab['_']) - whitespace_repr

# One-hot vectors of the two special tokens used by simple_extract_or_translate.
wh_token = tf.gather(embeddings, vocab['_']) # [repr_dim]
translate_token = tf.gather(embeddings, vocab['isa']) # [repr_dim]

# NOTE(review): presumably a multiplier applied to scores before a softmax
# (larger = sharper); confirm against where it is consumed.
softmax_slope = 2.0

def simple_extract_or_translate(questions, statements):
    """Treat each statement both as a possible answer source and as a rewrite rule.

    Row i of ``questions`` is paired with row i of ``statements``.

    Args:
        questions: [batch_size, length, repr_dim]
        statements: [batch_size, length, repr_dim]. Positions 0, 1 and 2 are
            indexed directly, so ``length`` must be at least 3.

    Returns:
        A 4-tuple ``(answer, match_score, new_questions, tr_score)``:
          answer: [batch_size, length, repr_dim], the extracted answer token at
            position 0 followed by ``length - 1`` whitespace tokens.
          match_score: [batch_size], total token-by-token overlap between
            question and statement.
          new_questions: [batch_size, length, repr_dim], the question with the
            statement's first token softly replaced by its third token.
          tr_score: [batch_size], the summed lhs match scores plus the score of
            the translate token at statement position 1.
    """
    # calculate total token-by-token match score
    match_score = tf.reduce_sum(questions * statements, [1, 2])  # [batch_size]

    # find WH token: softmax attention over question positions by overlap with wh_token
    wh_scores = tf.reduce_sum(questions * wh_token, 2)  # [batch_size, length]
    wh_probs = tf.nn.softmax(wh_scores)  # [batch_size, length]

    # extract answer token: attention-weighted sum of the statement tokens
    answer_token = tf.reduce_sum(statements * tf.expand_dims(wh_probs, 2), 1)  # [batch_size, repr_dim]

    # answer should be a sequence: append whitespace tokens after the answer token.
    # The tile multiples evaluate to [batch_size, length - 1, 1].
    padding_ws = tf.tile(tf.expand_dims(tf.expand_dims(whitespace_repr, 0), 0),
                         tf.shape(statements) * [1, 1, 0] + [0, -1, 1])
    answer = tf.concat(1, [tf.expand_dims(answer_token, 1), padding_ws])

    # check for TR token at index 1
    tr_match_score = tf.reduce_sum(statements[:, 1, :] * translate_token, 1)  # [batch_size]
    # statement layout: lhs TR rhs
    lhs = statements[:, 0:1, :]  # [batch_size, 1, repr_dim]
    rhs = statements[:, 2:3, :]  # [batch_size, 1, repr_dim]

    # find best match with lhs in question
    lhs_match_scores = tf.reduce_sum(questions * lhs, 2)  # [batch_size, length]
    lhs_prob = tf.nn.softmax(lhs_match_scores)  # [batch_size, length]
    lhs_prob_expanded = tf.expand_dims(lhs_prob, 2)  # [batch_size, length, 1]

    # then replace it with rhs: subtract the lhs mass and add the same amount of rhs
    replacement = rhs * lhs_prob_expanded
    to_remove = lhs * lhs_prob_expanded
    new_questions = questions - to_remove + replacement

    # translation score
    tr_score = tf.reduce_sum(lhs_match_scores, 1) + tr_match_score

    # Design notes kept from the original author:
    # X Y  _ &
    # A => Y
    # X A  _
    # find tokens in statement that are not the TR token and aren't matched in question
    # A * *
    # find tokens in questions that are matched with the statement (assume there is a single one so use softmax)
    # * Y *
    # remove found token from question and replace with token from statement
    # X A _

    return answer, match_score, new_questions, tr_score

def repr_text_batch(texts):
    """Convert tokenised texts into equally long rows of vocabulary ids.

    Args:
        texts: list of token lists.

    Returns:
        One list of ids per text, each padded on the right with the id of
        WHITESPACE up to the length of the longest text. An empty ``texts``
        yields an empty list (``max`` would otherwise raise on it).
    """
    if not texts:
        return []
    max_length = max(len(text) for text in texts)
    pad_id = vocab[WHITESPACE]
    return [
        [vocab[token] for token in text] + [pad_id] * (max_length - len(text))
        for text in texts
    ]

def embed_statements(statements):
    """Look up the row of ``embeddings`` for every token id in ``statements``.

    The result has the shape of ``statements`` plus one trailing axis holding
    the selected embedding row.
    """
    token_vectors = tf.gather(embeddings, statements)
    return token_vectors

def to_feed_dict(statements, questions):
    """Build the feed dict for the statement and question placeholders.

    Statements and questions are encoded together in one call to
    ``repr_text_batch`` so both groups are padded to the same length, then
    split back into their original groups.
    """
    num_statements = len(statements)
    encoded = repr_text_batch(statements + questions)
    return {
        statement_placeholder: encoded[:num_statements],
        question_placeholder: encoded[num_statements:],
    }

def match_all(kb, questions):
    """Run ``simple_extract_or_translate`` on every (question, kb entry) pair and pool per question.

    Args:
        kb: [kb_size, max_length, repr_dim]
        questions: [batch_size, max_length, repr_dim]. The flattening below
            reuses kb's trailing dimensions for both tensors, so max_length and
            repr_dim must agree with ``kb``.

    Returns:
        A 4-tuple ``(global_answer, global_match_score, global_translation,
        global_translation_score)`` where the answer/translation tensors are
        [batch_size, max_length, repr_dim] and the scores are [batch_size].
    """
    # Pair every question with every kb entry: both become
    # [batch_size, kb_size, max_length, repr_dim].
    expanded_kb = tf.expand_dims(kb, 0)  # [1, kb_size, max_length, repr_dim]
    expanded_questions = tf.expand_dims(questions, 1)  # [batch_size, 1, max_length, repr_dim]
    # tile multiples evaluate to [batch_size, 1, 1, 1]
    tiled_kb = tf.tile(expanded_kb, tf.shape(expanded_questions) * [1, 0, 0, 0] + [0, 1, 1, 1])
    # tile multiples evaluate to [1, kb_size, 1, 1]
    tiled_questions = tf.tile(expanded_questions, tf.shape(expanded_kb) * [0, 1, 0, 0] + [1, 0, 1, 1])

    # now flatten the two leading axes into a single one of size batch_size * kb_size
    new_dim = tf.shape(kb)[0:1] * tf.shape(questions)[0:1]
    new_shape = tf.concat(0, [new_dim, tf.shape(kb)[1:]])
    batch_kb_shape = tf.shape(tiled_kb)[0:2]
    flat_kb = tf.reshape(tiled_kb, new_shape)
    flat_questions = tf.reshape(tiled_questions, new_shape)

    answers, match_scores, new_questions, tr_scores = simple_extract_or_translate(flat_questions, flat_kb)

    def aggregate_and_score(answers, match_scores):
        """Pool per-pair results over the kb axis, weighting by a softmax over the pair scores."""
        answers_reshaped = tf.reshape(answers, tf.shape(tiled_kb))  # [batch_size, kb_size, max_length, repr_dim]
        match_scores_reshaped = tf.reshape(match_scores, batch_kb_shape)  # [batch_size, kb_size]
        # Use the shared slope constant rather than repeating its value as a literal,
        # so tuning softmax_slope actually changes this softmax.
        match_probs = tf.nn.softmax(match_scores_reshaped * softmax_slope)
        match_probs_expanded = tf.expand_dims(tf.expand_dims(match_probs, 2), 3)  # [batch_size, kb_size, 1, 1]
        weighted_answer = tf.reduce_sum(match_probs_expanded * answers_reshaped, 1)
        global_match_score = tf.reduce_sum(match_scores_reshaped, 1)
        return weighted_answer, global_match_score

    # extraction
    global_answer, global_match_score = aggregate_and_score(answers, match_scores)
    # translation
    global_translation, global_translation_score = aggregate_and_score(new_questions, tr_scores)

    return global_answer, global_match_score, global_translation, global_translation_score

def inference_steps(kb, questions, num_steps=1):
    """Unroll ``num_steps`` rounds of matching the questions against the kb.

    Each round calls ``match_all`` on the questions produced by the previous
    round, then blends the translated questions with the current ones.

    Args:
        kb: [kb_size, max_length, repr_dim]
        questions: [batch_size, max_length, repr_dim]
        num_steps: number of rounds to unroll.

    Returns:
        A list with one tuple per step:
        ``(answer, match_score, questions_after_step, translation_score)``.
        The list is empty when ``num_steps`` is 0.
    """
    current_questions = questions
    steps = []
    for _ in range(num_steps):
        # Feed the questions as rewritten by the previous step, not the original
        # ones; otherwise every step would repeat the first step's computation.
        current_result, current_match_score, global_translation, current_translation_score = match_all(
            kb, current_questions)
        # if global_match_score >> global_translation_score we should never change the question again
        prob_translate = tf.sigmoid(1.0 * (current_translation_score - current_match_score + 10.))
        # The scores are [batch_size]; add two trailing axes so the gate lines up
        # with the batch axis of the [batch_size, max_length, repr_dim] questions
        # instead of broadcasting against their last axis.
        prob_translate = tf.expand_dims(tf.expand_dims(prob_translate, 1), 2)
        current_questions = prob_translate * global_translation + (1.0 - prob_translate) * current_questions
        steps.append((current_result, current_match_score, current_questions, current_translation_score))

    return steps
        

def decode(statements):
    """Turn a batch of token representations back into lists of words.

    Args:
        statements: [batch_size, length, repr_dim]; in this notebook it is
            called with arrays returned by ``sess.run``.

    Returns:
        One list of words per batch row: at each position the vocabulary entry
        whose embedding has the largest dot product with the representation,
        with WHITESPACE entries dropped.

    Uses the notebook-level ``sess``, ``embeddings`` and ``vocab``; every call
    creates new ops in the default graph.
    """
    # Dot product of every position with every embedding row: [batch_size, length, vocab_size]
    compare_all = tf.reduce_sum(tf.expand_dims(statements, 2) * embeddings, 3)
    # k is left at its default, so only the single best id per position is kept
    top_k = tf.nn.top_k(compare_all)

    _values, indices = sess.run(top_k)

    results = []
    for seq in indices:
        decoded = (vocab.key_by_id(token[0]) for token in seq)
        results.append([word for word in decoded if word != WHITESPACE])
    return results
        
# Integer token-id inputs with two dynamic axes; to_feed_dict feeds them
# one row of ids per text.
statement_placeholder = tf.placeholder(tf.int32,(None,None))
question_placeholder = tf.placeholder(tf.int32, (None,None))

# Single-pair graph: (answer, match_score, new_questions, tr_score).
result = simple_extract_or_translate(embed_statements(question_placeholder), 
                                           embed_statements(statement_placeholder))

# NOTE(review): rebinding `sess` here does not close a session created by an earlier cell.
sess = tf.Session()

# Pair the rule "eagle isa bird" (statements[1:2]) with the question "eagle can _".
e_result, e_score, t_result, t_score = sess.run(result, feed_dict=to_feed_dict(statements[1:2],questions[:1]))

# Not the last expression of the cell, so this value is not displayed.
decode(t_result)
# e_score
# t_result

# One inference step over the whole knowledge base; [-1] selects the tuple of the final step.
all_answer, all_match_score, all_translation, all_translation_score = sess.run(
    inference_steps(embed_statements(statement_placeholder), 
                   embed_statements(question_placeholder))[-1],
        feed_dict=to_feed_dict(statements[0:4],questions[:1]))

# Displayed as the cell output: decoded rewritten question plus both scores.
decode(all_translation), all_match_score, all_translation_score
# all_answer

([['bird', 'can', '_']],
 array([ 5.], dtype=float32),
 array([ 3.], dtype=float32))

In [51]:
# Evaluate the one-hot vector of the padding token ('[WS]' is the same string as
# WHITESPACE). Not the last expression of the cell, so the value is not displayed.
sess.run(tf.gather(embeddings, vocab['[WS]'])) # [repr_dim]

def decode(statements):
    """Turn a batch of token representations back into lists of words.

    Args:
        statements: [batch_size, length, repr_dim]; in this notebook it is
            called with arrays returned by ``sess.run``.

    Returns:
        One list of words per batch row: at each position the vocabulary entry
        whose embedding has the largest dot product with the representation,
        with WHITESPACE entries dropped.

    Uses the notebook-level ``sess``, ``embeddings`` and ``vocab``; every call
    creates new ops in the default graph.
    """
    # Dot product of every position with every embedding row: [batch_size, length, vocab_size]
    compare_all = tf.reduce_sum(tf.expand_dims(statements, 2) * embeddings, 3)
    # k is left at its default, so only the single best id per position is kept
    top_k = tf.nn.top_k(compare_all)

    _values, indices = sess.run(top_k)

    results = []
    for seq in indices:
        decoded = (vocab.key_by_id(token[0]) for token in seq)
        results.append([word for word in decoded if word != WHITESPACE])
    return results

# Decode the extraction result. `e_result` is not defined in this cell; it is
# assigned by the extract_or_translate cell, which must have been run first.
decode(e_result)

[['fly']]

In [191]:
## import numpy as np
import quebap.projects.playqa.util as pqutil

# NOTE(review): the meaning of the two constructor arguments is not visible here;
# presumably (max_length, repr_dim), matching the 2-step, 3-feature inputs below -- confirm.
cell = pqutil.CompactifyCell(2,3)

# Two sequences of two steps with three features each: shape [2, 2, 3].
inputs = tf.constant([
    [[0.5, 0, -1], [0, 2, -1]],
    [[3, 0, -1], [0, 4, -1]]
])

# One mask value per (sequence, step). In the recorded output the step masked
# with 0 (first step of the second sequence) is absent and the remaining step
# is packed to the front.
masks = tf.constant([[1.,1],[0,1]])

# NOTE(review): `playqa` is not bound by any import in the visible cells (the util
# module is imported as `pqutil`), so this line depends on a name left over in the
# kernel -- confirm whether `pqutil.to_inputs` is what is meant.
# NOTE(review): `result` is also the name of the graph tuple built in the
# extract_or_translate cell; it is rebound here.
outputs, (result,counter) = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float32,
        sequence_length=(2,2),
        inputs=playqa.to_inputs(inputs,masks))

# `sess` is the session created by an earlier cell.
sess.run((result,counter))
# sess.run(inputs[:,0,0:1])
# mask = tf.expand_dims(tf.constant([[1.,1],[1, 0]]),2) # [batch,size, max_length]
# sess.run(tf.concat(2, [mask, inputs]))

(array([[ 0.5,  0. , -1. ,  0. ,  2. , -1. ],
        [ 0. ,  4. , -1. ,  0. ,  0. ,  0. ]], dtype=float32),
 array([[ 1.,  0.],
        [ 0.,  1.]], dtype=float32))