# Recursive Neural Networks in TensorFlow

My attempt to implement a more robust neural network class in TensorFlow, and also to build a recursive (tree-structured) network, which I haven't yet seen an example of online.

In [14]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
% matplotlib inline
import re
from spacy.en import English
nlp = English()

In [8]:
class TreeRNN(object):
    """A tree-structured recursive neural net with a variable number of children.

    Uses an sklearn-style interface: construct the model, then call
    ``partial_fit`` repeatedly with batches of trees.

    Each tree node must expose ``children`` (a list of nodes), ``word_index``
    (row of the word-embedding matrix) and ``class_index`` (integer target
    label); the forward pass writes ``state``, ``cross_entropy`` and ``loss``
    back onto every node.

    NOTE(review): the ops for each batch are built inside ``partial_fit``, so
    the default graph grows with every call. That is fine for experiments, but
    a long training run should move to a fixed graph.
    """

    def __init__(self, num_classes=3, embedding_size=300, vocab_size=None,
                 lambda_reg=1e-4, learning_rate=1.0):
        """Build variables and optimizer, and start an interactive session.

        Args:
            num_classes: number of softmax output classes.
            embedding_size: size of word vectors and hidden states.
            vocab_size: rows in the word-embedding matrix; defaults to
                ``embedding_size`` (the previously hard-coded shape).
            lambda_reg: L2 penalty on the softmax weights and biases.
            learning_rate: Adagrad learning rate.
        """
        # network options
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.vocab_size = embedding_size if vocab_size is None else vocab_size
        self.lambda_reg = lambda_reg
        self.learning_rate = learning_rate

        self._create_variables()
        # partial_fit needs self.optimizer, so create it up front
        self._create_optimizer()

        # op to initialize all the variables created so far
        init = tf.initialize_all_variables()

        # create a running session and start up the graph;
        # training and predicting with the graph is done from outside
        self.session = tf.InteractiveSession()
        self.session.run(init)
        # The optimizer creates its slot variables lazily (on the first
        # apply_gradients call). Remember what is already initialized so those
        # new variables can be initialized before the training op runs.
        self._initialized_var_names = set(v.name for v in tf.all_variables())

    def _create_variables(self):
        """Initialize the network weights."""
        # recurrent unit weights (identity init) and bias; float32 so they can
        # be multiplied with the float32 embeddings
        self.w_input = tf.Variable(
            np.identity(self.embedding_size, dtype=np.float32))
        self.w_state = tf.Variable(
            np.identity(self.embedding_size, dtype=np.float32))
        self.b_state = tf.Variable(tf.zeros([self.embedding_size]))

        # word embedding matrix: one row per vocabulary entry
        self.word_embeddings = tf.Variable(tf.truncated_normal(
            [self.vocab_size, self.embedding_size],
            mean=0.0, stddev=1.0, dtype=tf.float32, seed=0))

        # softmax weights and biases
        self.softmax_w = tf.Variable(tf.random_uniform(
            [self.num_classes, self.embedding_size],
            minval=-.5, maxval=.5, dtype=tf.float32, seed=0))
        self.softmax_b = tf.Variable(tf.zeros([self.num_classes]))

    def _create_optimizer(self):
        """Create the optimizer (losses are built per batch in partial_fit)."""
        self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)

    def _recurrent_cell(self, input_vector, child_states):
        """Combine a node's input vector with its children's states.

        Computes tanh(W_input x + sum_c W_state h_c + b) on column vectors.

        Args:
            input_vector: word vector, shape [embedding_size, 1].
            child_states: non-empty list of [embedding_size, 1] states.

        Returns:
            The node state, shape [embedding_size, 1].
        """
        state_transforms = [tf.matmul(self.w_state, h) for h in child_states]
        child_sum = tf.add_n(state_transforms)
        # reshape the bias to a column so it adds elementwise instead of
        # broadcasting [E, 1] + [E] into an [E, E] matrix
        bias = tf.reshape(self.b_state, [self.embedding_size, 1])
        inner_sum = tf.matmul(self.w_input, input_vector) + child_sum + bias
        return tf.tanh(inner_sum)

    def _forward_pass(self, node):
        """Recursively build the state and loss ops for the subtree at ``node``.

        Sets ``node.state``, ``node.cross_entropy`` and ``node.loss``.

        Returns:
            (state, loss): the [embedding_size, 1] state tensor and a scalar
            tensor holding the cross entropy summed over the whole subtree.
        """
        loss = 0.0
        child_states = []
        # children first (post-order), collecting their states and losses
        for child in node.children:
            child_state, child_loss = self._forward_pass(child)
            child_states.append(child_state)
            loss += child_loss
        if not child_states:
            # leaves get one all-zero child state so add_n has an input
            child_states = [tf.zeros([self.embedding_size, 1])]

        # now we're back at _this_ node: look up its word as a column vector
        word_vec = tf.reshape(
            tf.nn.embedding_lookup(self.word_embeddings, node.word_index),
            [self.embedding_size, 1])

        # compute the rnn over this node
        node.state = self._recurrent_cell(word_vec, child_states)

        # classification cross entropy for the softmax layer. Logits and
        # labels are [1, num_classes] rows; tensors cannot be item-assigned,
        # so the target distribution is built with one_hot.
        logits = tf.reshape(tf.matmul(self.softmax_w, node.state),
                            [1, self.num_classes]) + self.softmax_b
        labels = tf.one_hot([node.class_index], self.num_classes,
                            dtype=tf.float32)
        cross_entropy = tf.reduce_sum(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                    labels=labels))

        node.cross_entropy = cross_entropy
        # loss for each node also includes loss from descendants
        node.loss = cross_entropy + loss
        return node.state, node.loss

    def _initialize_new_variables(self):
        """Initialize variables created since the last check (e.g. Adagrad slots)."""
        new_vars = [v for v in tf.all_variables()
                    if v.name not in self._initialized_var_names]
        if new_vars:
            self.session.run(tf.initialize_variables(new_vars))
            self._initialized_var_names.update(v.name for v in new_vars)

    def partial_fit(self, tree_batch):
        """Run one L2-regularized, gradient-clipped Adagrad step on a batch.

        Args:
            tree_batch: list of root nodes.

        Returns:
            The average per-tree loss evaluated in the same session run as
            the update, or None for an empty batch (no step is taken).
        """
        if not tree_batch:
            return None
        batch_size = len(tree_batch)
        losses = [self._forward_pass(tree)[1] for tree in tree_batch]
        # losses are symbolic tensors: reduce them in the graph, not in numpy
        avg_loss = tf.add_n(losses) / float(batch_size)
        objective = avg_loss + self.lambda_reg * (
            tf.reduce_sum(tf.square(self.softmax_w))
            + tf.reduce_sum(tf.square(self.softmax_b)))

        # we need to perform gradient clipping explicitly
        params = [self.word_embeddings, self.w_input, self.w_state,
                  self.b_state, self.softmax_w, self.softmax_b]
        grads_and_vars = self.optimizer.compute_gradients(objective, params)
        # variables unreachable from the objective come back with a None grad
        clipped_grads_and_vars = [(tf.clip_by_norm(g, 1.), v)
                                  for g, v in grads_and_vars if g is not None]
        minimizer = self.optimizer.apply_gradients(clipped_grads_and_vars)

        self._initialize_new_variables()
        _, avg_loss_value = self.session.run([minimizer, avg_loss])
        print("Avg perplexity: %0.4f" % np.exp(avg_loss_value))
        return avg_loss_value

In [9]:
class Node(object):
    """Node for defining trees with adjacency lists.

    All constructor arguments are optional, so ``Node()`` still works.

    Attributes:
        head: stored as given; not read elsewhere in this file
            (presumably the node's head word — confirm with the tree builder).
        word_index: row of the word-embedding matrix, read by TreeRNN.
        class_index: integer target label, read by TreeRNN.
        children: list of child nodes (a fresh list per node).
        loss, state, cross_entropy: outputs written by TreeRNN's forward pass.
    """
    def __init__(self, head=None, word_index=None, class_index=None,
                 children=None):
        self.head = head
        self.word_index = word_index
        self.class_index = class_index
        # never share one default list between nodes
        self.children = [] if children is None else list(children)
        # filled in by the forward pass
        self.loss = None
        self.state = None
        self.cross_entropy = None

In [10]:
split_delims = [' ', ',', '.', ';', ':', '%', '"', '$', '^']
def split(string, delimiters=split_delims, maxsplit=0):
    """Split ``string`` on any of the literal ``delimiters``.

    Delimiters are regex-escaped, so they match literally. Adjacent
    delimiters produce empty strings, exactly as ``re.split`` does.

    Args:
        string: text to split.
        delimiters: iterable of literal delimiter strings.
        maxsplit: maximum number of splits; 0 means no limit.

    Returns:
        List of substrings; ``[string]`` when no delimiters are given.
    """
    delimiters = list(delimiters)
    if not delimiters:
        # an empty pattern would match (and split) at every position
        return [string]
    regex_pattern = '|'.join(map(re.escape, delimiters))
    # maxsplit by keyword: passing it positionally is deprecated (Python 3.13)
    return re.split(regex_pattern, string, maxsplit=maxsplit)

def convert_raw_x(line):
    """Convert a raw line of SemEval data into a usable form.

    Returns a triple ``(doc, e1_index, e2_index)``:

    - ``doc`` is the lower-cased sentence, re-joined with single spaces and
      passed through the module-level spaCy ``nlp`` pipeline.
    - ``e1_index`` / ``e2_index`` are the positions, in that space-joined word
      list, of the first word at or after the ``<e1>`` / ``<e2>`` tags.

    Raises:
        ValueError: if the line has no double quote, or if an entity tag is
            missing or not followed by any word.
    """
    s = line.strip()
    # keep the text between the first " and the last "
    s = s[s.index('"') + 1: -(s[::-1].index('"') + 1)]
    # We assume the first token following the <e1>, <e2> tags is the entity
    # word. This is a big assumption; hopefully phrases will be in subtrees or
    # in heads of the parse trees.
    # TODO: this can be addressed by making it a 5-tuple with the endpoints
    # also encoded.
    words = []
    e1_index = e2_index = None
    for token in split(s):
        # An opening tag marks the index of the next word we keep. Indices are
        # taken in the filtered list (empty strings from re.split dropped), so
        # they line up with the sentence that is actually joined below.
        if '<e1>' in token and e1_index is None:
            e1_index = len(words)
        if '<e2>' in token and e2_index is None:
            e2_index = len(words)
        for tag in ('<e1>', '</e1>', '<e2>', '</e2>'):
            token = token.replace(tag, '')
        if token != '':
            words.append(token.lower())

    if (e1_index is None or e2_index is None
            or max(e1_index, e2_index) >= len(words)):
        raise ValueError("could not locate <e1>/<e2> entity words in line: %r"
                         % line)
    # NOTE(review): spaCy re-tokenizes the joined string and may split some
    # words (e.g. contractions, hyphenated words) into several tokens, which
    # would shift these indices relative to the doc's tokens — confirm on data.
    doc = nlp(u' '.join(words))
    return (doc, e1_index, e2_index)
    
label2int = {}  # running label -> int map; ids are handed out in first-seen order
def convert_raw_y(line):
    """Map a raw SemEval label line to its integer id.

    Surrounding whitespace is stripped first. A label that has not been seen
    yet receives the next free id, i.e. the current size of ``label2int``.
    """
    label = line.strip()
    # len() is evaluated before setdefault inserts, so an unseen label gets
    # the next consecutive id while a known label keeps its existing one
    return label2int.setdefault(label, len(label2int))

In [16]:
def load_semeval_data(
        training_txt_file='SemEval2010_task8_all_data/SemEval2010_task8_training/TRAIN_FILE.TXT'):
    """Load the SemEval-2010 Task 8 training file and split off a validation set.

    Each example spans 4 lines:
    - the sentence, converted with ``convert_raw_x``;
    - the label, converted with ``convert_raw_y``;
    - two further lines (comment and blank), which are ignored.

    The first 7109 examples go to ``train`` and the last 891 to ``valid``.

    Args:
        training_txt_file: path to TRAIN_FILE.TXT.

    Returns:
        (train, valid): dicts with keys 'x' (converted sentences) and
        'y' (integer labels).
    """
    num_examples = 8000
    num_valid = 891
    validation_index = num_examples - num_valid  # index of first validation example (7109)
    train = {'x': [], 'y': []}
    valid = {'x': [], 'y': []}
    # read inside a with-block so the file handle is always closed
    with open(training_txt_file, 'r') as f:
        text = f.readlines()
    assert len(text) // 4 == num_examples
    for cursor in range(len(text) // 4):  # each 4 lines is a datum
        target = train if cursor < validation_index else valid
        target['x'].append(convert_raw_x(text[cursor * 4]))
        target['y'].append(convert_raw_y(text[cursor * 4 + 1]))
        # ignore comments and blanks (+2, +3)

    assert len(train['y']) == validation_index and len(valid['y']) == num_valid
    # 2 for each of the 9 asymmetric relations plus 1 other; compare against
    # list(range(...)) so the check also holds on Python 3, where range() is
    # not a list
    assert sorted(label2int.values()) == list(range(19))

    return train, valid


semeval_train, semeval_valid = load_semeval_data()

In [None]:
def convert_semeval_to_sample(semeval_datum):
    