In [95]:
import os, sys, shutil, time, itertools
import math, random
from collections import OrderedDict, defaultdict

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from collections import defaultdict
from tqdm import tqdm

class Vocab(object):
  def __init__(self):
    self.word_to_index = {}
    self.index_to_word = {}
    self.unknown = '<unk>'
    self.add_word(self.unknown, count=0)

  def add_word(self, word, count=1):
    if word not in self.word_to_index:
      index = len(self.word_to_index)
      self.word_to_index[word] = index
      self.index_to_word[index] = word

  def build_vocab(self, words):
    for word in words:
      self.add_word(word)
    print('{} total unique words'.format(len(self.word_to_index)))

  def encode(self, tokens):
    if type(tokens) is str:
      tokens = [tokens]
    ids = []
    for token in tokens:
      if token not in self.word_to_index:
        token = self.unknown
      ids.append(self.word_to_index[token])

    return ids

  def decode(self, ids):
    if type(ids) is int:
      ids = [id]
    tokens = []
    for id in ids:
     tokens.append(self.index_to_word[id])

    return tokens

  def __len__(self):
    return len(self.word_freq)

  def save(self,vocab_file_name):
    save_dic = {
        'word_to_index': self.word_to_index,
        'index_to_word':self.index_to_word,
    }
    np.save(vocab_file_name, save_dic)
    
  def load(self,vocab_file_name):
    loaded_dic = np.load(vocab_file_name+".npy").item()
    self.word_to_index = loaded_dic['word_to_index']
    self.index_to_word = loaded_dic['index_to_word']
    
class Node:  # a node in the tree
  def __init__(self, label, word=None):
    self.label = label
    self.word = word
    self.parent = None  # reference to parent
    self.left = None  # reference to left child
    self.right = None  # reference to right child
    # true if I am a leaf (could have probably derived this from if I have
    # a word)
    self.isLeaf = False
    # true if we have finished performing fowardprop on this node (note,
    # there are many ways to implement the recursion.. some might not
    # require this flag)

  def __str__(self):
    if self.isLeaf:
      return '[{0}:{1}]'.format(self.word, self.label)
    return '({0} <- [{1}:{2}] -> {3})'.format(self.left, self.word, self.label, self.right)


class Tree:

  def __init__(self, treeString, openChar='(', closeChar=')'):
    tokens = []
    self.open = '('
    self.close = ')'
    for toks in treeString.strip().split():
      tokens += list(toks)
    self.root = self.parse(tokens)
    # get list of labels as obtained through a post-order traversal
    self.labels = get_labels(self.root)
    self.num_words = len(self.labels)

  def parse(self, tokens, parent=None):
    assert tokens[0] == self.open, "Malformed tree"
    assert tokens[-1] == self.close, "Malformed tree"

    split = 2  # position after open and label
    countOpen = countClose = 0

    if tokens[split] == self.open:
      countOpen += 1
      split += 1
    # Find where left child and right child split
    while countOpen != countClose:
      if tokens[split] == self.open:
        countOpen += 1
      if tokens[split] == self.close:
        countClose += 1
      split += 1

    # New node
    node = Node(int(tokens[1]))  # zero index labels

    node.parent = parent

    # leaf Node
    if countOpen == 0:
      node.word = ''.join(tokens[2:-1]).lower()  # lower case?
      node.isLeaf = True
      return node

    node.left = self.parse(tokens[2:split], parent=node)
    node.right = self.parse(tokens[split:-1], parent=node)

    return node

  def get_words(self):
    leaves = getLeaves(self.root)
    words = [node.word for node in leaves]
    return words


def leftTraverse(node, nodeFn=None, args=None):
  """
  Recursive function traverses tree
  from left to right.
  Calls nodeFn at each node
  """
  if node is None:
    return
  leftTraverse(node.left, nodeFn, args)
  leftTraverse(node.right, nodeFn, args)
  nodeFn(node, args)


def getLeaves(node):
  if node is None:
    return []
  if node.isLeaf:
    return [node]
  else:
    return getLeaves(node.left) + getLeaves(node.right)


def get_labels(node):
  if node is None:
    return []
  return get_labels(node.left) + get_labels(node.right) + [node.label]


def clearFprop(node, words):
  node.fprop = False



In [96]:
dataSet = "train"
file = 'trees/%s.txt' % dataSet
print("Loading %s trees.." % dataSet)
with open(file, 'r') as fid:
    trees = [Tree(l) for l in fid.readlines()]
    

Loading train trees..


In [97]:
vocab = Vocab()
# Get list of tokenized sentences
train_sents = [t.get_words() for t in trees]
# Get list of all words
all_words = list(itertools.chain.from_iterable(train_sents))

# Build Vocab
vocab.build_vocab(all_words)
vocab.save('sst_vocab')

16581 total unique words


In [98]:
vocab.load('sst_vocab')

In [136]:
example_features = []
for example_id, tree in enumerate(trees):
    words = tree.get_words()
    node  = tree.root
    nodes_list = []
    leftTraverse(node, lambda node, args: args.append(node), nodes_list)
    node_to_index = OrderedDict()
    for i in range(len(nodes_list)):
      node_to_index[nodes_list[i]] = i
    example_features.append({
      'example_id': example_id,
      'is_leaf': [int(node.isLeaf) for node in nodes_list],
      'left_children': [node_to_index[node.left] if
                                         not node.isLeaf else -1
                                         for node in nodes_list],
      'right_children': [node_to_index[node.right] if
                                          not node.isLeaf else -1
                                          for node in nodes_list],
      'node_word_ids': [vocab.encode(node.word)[0] if
                                             node.word else -1
                                             for node in nodes_list],
      'labels': [node.label for node in nodes_list],
      'length': len(nodes_list)
    })

In [156]:
def get_tf_features(example_feaures):
    """Convert our own representation of an example's features to Features class for TensorFlow dataset.
    """
    features = tf.train.Features(feature={
        "example_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[example_feaures['example_id']])),
        "is_leaf": tf.train.Feature(int64_list=tf.train.Int64List(value=example_feaures['is_leaf'])),
        "left_children": tf.train.Feature(int64_list=tf.train.Int64List(value=example_feaures['left_children'])),
        "right_children": tf.train.Feature(int64_list=tf.train.Int64List(value=example_feaures['right_children'])),
        "node_word_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=example_feaures['node_word_ids'])),
        "labels": tf.train.Feature(int64_list=tf.train.Int64List(value=example_feaures['labels'])),
        "length": tf.train.Feature(int64_list=tf.train.Int64List(value=[example_feaures['length']])),
      })
    return features

tf_example_features = []
for example in example_features:
       tf_example_features.append(get_tf_features(example))

In [157]:
with tf.python_io.TFRecordWriter('train_trees') as tf_record_writer:
    for example in tqdm(tf_example_features):
        tf_record = tf.train.Example(features=example)
        tf_record_writer.write(tf_record.SerializeToString())

100%|██████████| 8544/8544 [00:00<00:00, 9859.53it/s] 


In [201]:
def parse_sst_tree_examples(example):
    """Load an example from TF record format."""
    features = {"example_id": tf.FixedLenFeature([], tf.int64),
                "length": tf.FixedLenFeature([], tf.int64),
               "is_leaf": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
               "left_children": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
               "right_children": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
               "node_word_ids": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
               "labels": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True)}
    parsed_example = tf.parse_single_example(example, features=features)
    
    example_id = parsed_example["example_id"]
    length = parsed_example["length"]
    tf.logging.info(length)
    is_leaf = parsed_example["is_leaf"]
    tf.logging.info(is_leaf)
    left_children = parsed_example["left_children"]
    right_children = parsed_example["right_children"]
    node_word_ids = parsed_example["node_word_ids"]
    labels = parsed_example["labels"]

    return example_id, length, is_leaf, left_children, right_children, node_word_ids, labels
                
dataset = tf.data.TFRecordDataset("train_trees")
dataset = dataset.map(parse_sst_tree_examples)
dataset = dataset.padded_batch(10, padded_shapes=([],[],[None],[None],[None],[None],[None]))
iterator = dataset.make_initializable_iterator()

example_id, length, is_leaf, left_children, right_children, node_word_ids, labels = iterator.get_next()

INFO:tensorflow:Tensor("ParseSingleExample/ParseSingleExample:4", shape=(), dtype=int64)
INFO:tensorflow:Tensor("ParseSingleExample/ParseSingleExample:1", shape=(?,), dtype=int64)


In [202]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(iterator.initializer)
    print(sess.run(is_leaf))

[[1 1 0 1 1 1 1 1 1 1 1 0 1 1 1 0 0 0 0 0 0 0 1 0 1 0 1 1 1 1 1 1 1 1 0 1
  1 0 0 0 1 1 1 0 1 0 1 1 1 0 0 0 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0
  0 0 0 0 0]
 [1 1 1 1 0 0 0 1 1 0 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 1 1 1 0 0 1 1 1 0 1 1
  0 0 1 1 0 1 0 1 1 1 1 1 0 0 0 1 1 0 0 1 1 1 1 0 0 1 0 0 0 0 0 0 0 0 1 0
  0 0 0 0 0]
 [1 1 1 0 0 1 1 1 0 1 1 0 0 0 0 1 1 1 1 0 0 1 1 0 1 1 0 1 1 1 0 0 1 1 1 0
  0 0 0 0 0 1 0 0 0 1 0 1 1 1 0 0 1 1 1 1 1 1 1 0 0 0 1 0 0 1 1 1 0 0 0 0
  0 0 0 1 0]
 [1 1 1 1 1 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 1 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0]
 [1 1 1 0 1 1 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0]
 [1 1 1 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 1 1 0 1 0 1 0 1 0 1 1 0 0
  0 0 0 0 0 1 1 1 1 1 1 1 1 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0]
 [1 1 1 0 0 1 1 1 1 1 1 0 0 1 1 1 