In [None]:
import ipywidgets as widgets
import tensorflow as tf
import uuid
from IPython import display
from dragnn.protos import spec_pb2
from dragnn.python import graph_builder
from dragnn.python import load_dragnn_cc_impl  # This loads the actual op definitions
from dragnn.python import render_parse_tree_graphviz
from dragnn.python import visualization
from google.protobuf import text_format
from syntaxnet import load_parser_ops  # This loads the actual op definitions
from syntaxnet import sentence_pb2
from tensorflow.python.platform import tf_logging as logging

# Only show TensorFlow log messages at WARN level or above (hides INFO spam).
logging.set_verbosity(logging.WARN)

# Parse the text-format MasterSpec that describes the DRAGNN components.
master_spec = spec_pb2.MasterSpec()
with open("data/master_spec_es.prototext") as spec_file:
    text_format.Merge(spec_file.read(), master_spec)

# Build the inference graph from the spec, inside its own tf.Graph.
graph = tf.Graph()
with graph.as_default():
    # A default-constructed GridPoint: no hyperparameter fields are set here.
    hyperparam_config = spec_pb2.GridPoint()
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    # The annotation component used on test sentences; tracing is enabled so
    # the trace explorer below can visualize how each annotation was produced.
    annotator = builder.add_annotation(enable_tracing=True)
    # A saver can save or load models; this notebook only loads (restores) one.
    builder.add_saver()

In [None]:
# One long-lived session, bound to the graph built above, serves every
# annotate_text() call in this notebook.
sess = tf.Session(graph=graph)
with graph.as_default():
    # Give every variable an initial value first, then overwrite the saved ones
    # from the checkpoint. 'save/restore_all' and 'save/Const:0' are the
    # standard tf.train.Saver restore op / filename tensor names; presumably
    # they come from builder.add_saver() above.
    sess.run(tf.global_variables_initializer())
    sess.run('save/restore_all',
             feed_dict={'save/Const:0': "data/mini-spanish.checkpoint"})

def annotate_text(text):
    """Annotate `text` with the restored model, as a one-sentence batch.

    The text is tokenized by splitting on whitespace; each token's start/end
    offsets are set to -1.

    Args:
      text: the raw input string.

    Returns:
      A (sentence_pb2.Sentence, trace) pair: the annotated sentence decoded from
      the annotator's output, and the corresponding entry of its traces output.
    """
    tokens = [sentence_pb2.Token(word=w, start=-1, end=-1) for w in text.split()]
    serialized_batch = [sentence_pb2.Sentence(token=tokens).SerializeToString()]
    fetches = [annotator['annotations'], annotator['traces']]
    with graph.as_default():
        annotations, traces = sess.run(
            fetches, feed_dict={annotator['input_batch']: serialized_batch})
        # We fed exactly one sentence, so exactly one result of each is expected.
        assert len(annotations) == 1
        assert len(traces) == 1
        return sentence_pb2.Sentence.FromString(annotations[0]), traces[0]

# Smoke test: make sure the restored model can annotate something
# (the trailing semicolon keeps IPython from echoing the result).
annotate_text("casa");

# Interactive trace explorer
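Run the cell below, then type a sentence into the text box and press Enter to render the annotator's trace for it. The model is a small Spanish one (`data/mini-spanish.checkpoint`), so use Spanish text; the input is split into tokens on whitespace, so separate punctuation from words with spaces.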

In [None]:
def _trace_explorer():  # put stuff in a function to not pollute global scope
    """Show a text box; on Enter, annotate its text and render the trace.

    Blank (empty or whitespace-only) submissions are ignored. Trace HTML
    produced in the submit callback is written into an ipywidgets ``Output``
    area: output emitted from widget callbacks is otherwise not displayed on
    some front ends (e.g. JupyterLab), and with recent ipywidgets an exception
    raised inside the ``with`` block is shown in that area too instead of
    being silently lost.
    """
    text = widgets.Text()
    display.display(text)

    output = visualization.InteractiveVisualization()
    display.display(display.HTML(output.initial_html()))

    # Destination for everything the callback displays, placed right below the
    # visualization so traces appear where they did before.
    trace_area = widgets.Output()
    display.display(trace_area)

    def handle_submit(sender):
        del sender  # unused
        if not text.value.strip():
            # text.split() would give zero tokens: nothing to annotate.
            return
        with trace_area:
            _, trace = annotate_text(text.value)
            display.display(display.HTML(output.show_trace(trace)))

    text.on_submit(handle_submit)


_trace_explorer()

# Interactive parse tree explorer
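Run the cell below, then type a sentence into the text box and press Enter to render its parse tree. As above, use Spanish text and separate punctuation from words with spaces, since the input is tokenized on whitespace.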

In [None]:
def _parse_tree_explorer():  # put stuff in a function to not pollute global scope
    """Show a text box; on Enter, annotate its text and render the parse tree.

    The tree HTML from render_parse_tree_graphviz.parse_tree_graph is placed in
    an HTML widget below the text box. A blank (empty or whitespace-only)
    submission clears the widget instead of annotating an empty sentence, so
    the displayed tree never belongs to stale input.
    """
    text = widgets.Text()
    display.display(text)
    html = widgets.HTML()
    display.display(html)

    def handle_submit(sender):
        del sender  # unused
        if not text.value.strip():
            # text.split() would give zero tokens: nothing to parse.
            html.value = ''
            return
        parse_tree, _ = annotate_text(text.value)
        html.value = render_parse_tree_graphviz.parse_tree_graph(parse_tree)

    text.on_submit(handle_submit)


_parse_tree_explorer()