### Role-factored Tensor Net

* Adapted from Weber et al. (2018) Event Representations with Tensor-based Compositions. AAAI18

* INPUT: a batch of events (5-tuples <v,s,o,p,po>), corresponding positive and negative instances (word2vec style training).
* OUTPUT: encoded event representation (as vectors).
* PROC:
    * Modeling the interaction between the predicate v with each of the arguments (equation: same for other args)
    $$v_s = v\cdot T\cdot s^{T}$$
    where $v\in R^{b\times d_e}$, $T\in R^{d_e,d_e,h}$, and $s\in R^{b\times d_e}$. This results in a batch of event vectors $\in R^{b,h}$ after batch matching and transposing.
    * Merging factor interactions additively
    $$e = v_s\cdot W_s + \dots + v_{po}\cdot W_{po}$$
    * Maximizing the distance between the input event and its positive examples while minimizing the distance between it and its negative examples through a max-margin loss
    $$\ell = \frac{1}{N}\sum_{i=1}^N \texttt{max}(0, m + \texttt{sim}(e, e_{neg}) - \texttt{sim}(e, e_{pos}))$$
* COMMENTS
    * For the simple demo I use dot product for distance metric rather than cosine as in the paper.
    * The network is not L2-regularized as in the paper, but the loss term can be easily added.

In [1]:
import sys
sys.path.insert(0, "/work/04233/sw33286/AIDA-SCRIPTS")

In [2]:
import os
import time
import random
import shutil
import dill
import numpy as np

import tensorflow as tf

from helpers import Indexer
from itertools import chain

### Prepare data

In [82]:
# Link to NYT data folder

nyt_code_dir = "/work/04233/sw33286/AIDA-DATA/nyt_eng_salads_event_sample_code/"
FILE_NAMES = os.listdir(nyt_code_dir)

# Link to dictionary information

info_path = "/work/04233/sw33286/AIDA-DATA/nyt_eng_salads_info/indexer_word2emb_100k.p"
indexer100k, word2emb100k = dill.load(open(info_path, 'rb'))
glove_embs = []
for i in range(len(indexer100k)):
    glove_embs.append(word2emb100k[indexer100k.get_object(i)])
glove_embs = np.array(glove_embs)
print(glove_embs.shape)

(100001, 300)


In [103]:
BATCH_SIZE = 32
CONTRA_BC = 10

def get_batch(edoc_a, edoc_b):
    edoc_a = list(chain.from_iterable(edoc_a)) # to a list of events
    edoc_b = list(chain.from_iterable(edoc_b))
    size_a, size_b = len(edoc_a), len(edoc_b)
    batch_x, batch_pos, batch_neg = [], [], []
    for _ in range(BATCH_SIZE//2):
        x_a = edoc_a[np.random.randint(0, size_a)]
        x_b = edoc_b[np.random.randint(0, size_b)]
        pos_a = [edoc_a[np.random.randint(0, size_a)] for _ in range(CONTRA_BC)]
        neg_a = [edoc_b[np.random.randint(0, size_b)] for _ in range(CONTRA_BC)]
        pos_b = [edoc_b[np.random.randint(0, size_b)] for _ in range(CONTRA_BC)]
        neg_b = [edoc_a[np.random.randint(0, size_a)] for _ in range(CONTRA_BC)]        
        batch_x += [x_a, x_b]
        batch_pos += [pos_a, pos_b]
        batch_neg += [neg_a, neg_b]
    return np.array(batch_x), np.array(batch_pos), np.array(batch_neg)

In [106]:
# Example: batch shapes

edoc_a, edoc_b, _ = dill.load(open(nyt_code_dir+FILE_NAMES[0],'rb'))
a,b1,b2 = get_batch(edoc_a, edoc_b)
a.shape, b1.shape, b2.shape

((32, 5), (32, 10, 5), (32, 10, 5))

### Role-factored Tensor Net

In [121]:
tf.reset_default_graph()

sess = tf.InteractiveSession()

VOCAB_SIZE, EMB_SIZE = glove_embs.shape
HID_SIZE = 100 # let event embs be of the same hid-size as role-factored arg vectors.

LEARNING_RATE = 1e-4

inputs = tf.placeholder(tf.int32, [BATCH_SIZE, 5], name='inputs') # <bc,nw-in-event=5>
inputs_pos = tf.placeholder(tf.int32, [BATCH_SIZE, CONTRA_BC, 5], name='inputs_pos') # <bc,ctr-bc,nw-in-event=5>
inputs_neg = tf.placeholder(tf.int32, [BATCH_SIZE, CONTRA_BC, 5], name='inputs_neg')
 
with tf.variable_scope('Embedding'):
    embeddings = tf.get_variable('embedding', [VOCAB_SIZE, EMB_SIZE],
                                 initializer=tf.contrib.layers.xavier_initializer())
    glove_init = embeddings.assign(glove_embs)

with tf.variable_scope('Role-factor'):
    T = tf.get_variable('T', [EMB_SIZE, EMB_SIZE, HID_SIZE], 
                        initializer=tf.contrib.layers.xavier_initializer())
    W_s = tf.get_variable('W_s', [HID_SIZE, HID_SIZE], initializer=tf.contrib.layers.xavier_initializer())
    W_o = tf.get_variable('W_o', [HID_SIZE, HID_SIZE], initializer=tf.contrib.layers.xavier_initializer())
    W_p = tf.get_variable('W_p', [HID_SIZE, HID_SIZE], initializer=tf.contrib.layers.xavier_initializer())
    W_po = tf.get_variable('W_po', [HID_SIZE, HID_SIZE], initializer=tf.contrib.layers.xavier_initializer())
    
def encode_events(inputs_):
    bc,_ = tf.unstack(tf.shape(inputs_))
    # Slicing inputs
    input_v = tf.squeeze(tf.slice(inputs_, [0,0],[bc,1]), -1)
        # op1. looking up the vector corresponds to the predicate: <bc,1>
        # op2. get rid of the vacuous dimension: <bc,>
    input_s = tf.squeeze(tf.slice(inputs_, [0,1],[bc,1]))
    input_o = tf.squeeze(tf.slice(inputs_, [0,2],[bc,1]))
    input_p = tf.squeeze(tf.slice(inputs_, [0,3],[bc,1]))
    input_po = tf.squeeze(tf.slice(inputs_, [0,4],[bc,1]))
    # Looking up
    input_v_embedded = tf.nn.embedding_lookup(embeddings, input_v) # <bc,emb>
    input_s_embedded = tf.transpose(tf.nn.embedding_lookup(embeddings, input_s),[1,0]) # <emb,bc>
    input_o_embedded = tf.transpose(tf.nn.embedding_lookup(embeddings, input_o),[1,0])
    input_p_embedded = tf.transpose(tf.nn.embedding_lookup(embeddings, input_p),[1,0])
    input_po_embedded = tf.transpose(tf.nn.embedding_lookup(embeddings, input_po),[1,0])
    # Role factoring
    vT = tf.transpose(tf.tensordot(input_v_embedded, T, axes=[[1],[0]]), [0,2,1])
        # op1. <bc,emb> * <emb,emb,hid> -> <bc,emb,hid>
        # op2. <bc,emb,hid> -> <bc,hid,emb>
    vTs = tf.matrix_diag_part(tf.transpose(tf.tensordot(vT, input_s_embedded, axes=[[2],[0]]), [1,0,2]))
        # op1. <bc,hid,emb> * <emb,bc> -> <bc,hid,bc>
        # op2. <bc,hid,bc> -> <hid,bc,bc>
        # op3. <hid,bc>
    vTo = tf.matrix_diag_part(tf.transpose(tf.tensordot(vT, input_o_embedded, axes=[[2],[0]]), [1,0,2]))
    vTp = tf.matrix_diag_part(tf.transpose(tf.tensordot(vT, input_p_embedded, axes=[[2],[0]]), [1,0,2]))
    vTpo = tf.matrix_diag_part(tf.transpose(tf.tensordot(vT, input_po_embedded, axes=[[2],[0]]), [1,0,2]))
    # Factor merging
    v_s = tf.matmul(W_s, vTs) # <hid,hid> * <hid,bc> -> <hid,bc>
    v_o = tf.matmul(W_p, vTo)
    v_p = tf.matmul(W_o, vTp)
    v_po = tf.matmul(W_po, vTpo)
    v_event = v_s + v_o + v_p + v_po # <hid,bc>
    return v_event

inputs_encoded = encode_events(inputs) # <hid,bc>
inputs_pos_encoded = tf.transpose(tf.map_fn(encode_events, inputs_pos, dtype=tf.float32), [0,2,1]) # <bc,hid,ctr-bc>
    # op1. event-encoder output: <bc,hid,ctr-bc>
    # op2. transpose: <bc,ctr-bc,hid>
inputs_neg_encoded = tf.transpose(tf.map_fn(encode_events, inputs_neg, dtype=tf.float32), [0,2,1])

with tf.variable_scope('Encode'):
    predictions = tf.identity(inputs_encoded, name='predictions')

with tf.variable_scope('Loss'):
    sim_pos = tf.matrix_diag_part(tf.transpose(tf.tensordot(inputs_pos_encoded, inputs_encoded, axes=[[2],[0]]), 
                                               [1,0,2]))
        # op1. tensordot: <bc,ctr-bc,hid> * <hid,bc> -> <bc,ctr-bc,bc>
        # op2. transpose: <ctr-bc,bc,bc>
        # op3. match bc: <ctr-bc,bc>
    sim_neg = tf.matrix_diag_part(tf.transpose(tf.tensordot(inputs_neg_encoded, inputs_encoded, axes=[[2],[0]]), 
                                               [1,0,2]))    
    loss = tf.reduce_mean(tf.reduce_mean(tf.maximum(0., 1. + sim_neg - sim_pos), axis=0))
        # op1. max(0, m + sim_neg - sim_pos), <ctr-bc,bc>
        # op2. average loss over contra instances: <bc,>
        # op3. average loss over batch
        
global_step = tf.Variable(0, name='global_step', trainable=False)
optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step, name='train_op')

sess.run(tf.global_variables_initializer())

In [122]:
NUM_EPOCHS = 10
# TRAIN_SIZE = 10
# VERBOSE = 1
TRAIN_SIZE = len(FILE_NAMES)
VERBOSE = 10

try:
    loss_track = []
    for e in range(NUM_EPOCHS):
        print('Epoch ', e+1)
        print('\n')
        curr_loss_track = []
        file_indices = np.random.choice(list(range(len(FILE_NAMES))), size=TRAIN_SIZE, replace=False)
        random.shuffle(file_indices)
        curr_loss_track, curr_accuracy_track = [], []
        for file_idx in file_indices:
            edoc_a, edoc_b, _ = dill.load(open(nyt_code_dir+FILE_NAMES[file_idx],'rb')) # context not added
            batch_x, batch_pos, batch_neg = get_batch(edoc_a, edoc_b)
            fd = {inputs:batch_x, inputs_pos:batch_pos, inputs_neg:batch_neg}
            _, step, loss_ = sess.run([train_op, global_step, loss], feed_dict=fd)
            curr_loss_track.append(loss_)
            if step%VERBOSE==0:
                print(' average batch loss at step {}: <{}>'.format(step, np.mean(curr_loss_track)))
        print('\n')
        print('  epoch mean loss: <{}>'.format(np.mean(curr_loss_track)))
        print('\n') 
        loss_track += curr_loss_track  
except KeyboardInterrupt:
    print('Stopped!')                      

Epoch  1


 average batch loss at step 10: <1.0>
 average batch loss at step 20: <1.0>
 average batch loss at step 30: <1.0>
 average batch loss at step 40: <1.0>
 average batch loss at step 50: <1.0>
 average batch loss at step 60: <1.0>
 average batch loss at step 70: <1.0>
 average batch loss at step 80: <1.0>
 average batch loss at step 90: <1.0>
 average batch loss at step 100: <0.9999977350234985>


  epoch mean loss: <0.9999977350234985>


Epoch  2


 average batch loss at step 110: <0.9999059438705444>
 average batch loss at step 120: <0.9979122877120972>
 average batch loss at step 130: <0.9878673553466797>
 average batch loss at step 140: <0.9838888049125671>
 average batch loss at step 150: <0.982790470123291>
 average batch loss at step 160: <0.9817882776260376>
 average batch loss at step 170: <0.9811098575592041>
 average batch loss at step 180: <0.9811125993728638>
 average batch loss at step 190: <0.9831833839416504>
 average batch loss at step 200: <0.9843214154243469>