In [1]:
import numpy as np
import numpy.random as npr
import json
import io
import time
import os.path
import six
import copy

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sb
sb.set_color_codes()
plt.rcParams['figure.figsize'] = (6,4)

import chainer as ch
import chainer.training.extensions # for some reason this isn't automatically imported

In [2]:
# Make the local `infonet` package (one directory up) importable.
# The membership guard keeps re-runs of this cell from appending duplicates.
import sys
import os
infonet_root = os.path.abspath('..')
if infonet_root not in sys.path:
    sys.path.append(infonet_root)

In [33]:
from infonet.vocab import Vocab
from infonet.preprocess import compute_flat_mention_labels, Entity_BIO_map
from infonet.util import convert_sequences, SequenceIterator




# Loading Data

In [9]:
# Load the annotated corpus. The `with` block closes the file handle deterministically,
# and JSON text is UTF-8, so the encoding is stated instead of relying on the locale.
DATA_PATH = '../data/ace_05_head_yaat.json'
with io.open(DATA_PATH, 'r', encoding='utf-8') as data_file:
    data = json.load(data_file)

In [10]:
# Build the token and boundary-label vocabularies in one pass over the documents.
# Each doc also gets its boundary labels (computed with Entity_BIO_map) stored
# under 'boundary_labels' for later use.
token_vocab = Vocab(min_count=5)
boundary_vocab = Vocab(min_count=0)
for doc in data.values():
    token_vocab.add(doc['tokens'])
    doc['boundary_labels'] = compute_flat_mention_labels(doc, Entity_BIO_map)
    boundary_vocab.add(doc['boundary_labels'])

In [11]:
def compute_mentions(doc, fine_grained=False):
    """Collect the node annotations of `doc` as mention tuples.

    Each mention is ``(start, end, label)``. ``start`` and ``end`` are the two
    entries of the annotation's ``ann-span``. ``label`` is
    ``"<node-type>:<type>"``, with ``":<subtype>"`` appended when
    `fine_grained` is true. Mentions are returned in annotation order; non-node
    annotations are skipped.
    """
    def label_of(ann):
        pieces = [ann['node-type'], ann['type']]
        if fine_grained:
            pieces.append(ann['subtype'])
        return ':'.join(pieces)

    return [(ann['ann-span'][0], ann['ann-span'][1], label_of(ann))
            for ann in doc['annotations']
            if ann['ann-type'] == u'node']
# sanity check: show the mentions extracted from the first document
print(compute_mentions(list(data.values())[0]))

def compute_relations(doc, fine_grained=False):
    """Collect the edge annotations of `doc` as relation tuples.

    Each relation is ``(left_start, left_end, right_start, right_end, label)``.
    The spans are the ``ann-span`` of the annotations whose ``ann-uid``
    matches the edge's ``ann-left`` / ``ann-right``. ``label`` is
    ``"<edge-type>:<type>"``, plus ``":<subtype>"`` when `fine_grained` is
    true. Relations are returned in annotation order. An edge that references
    an unknown uid raises KeyError.
    """
    by_uid = dict((ann['ann-uid'], ann) for ann in doc['annotations'])
    found = []
    for ann in doc['annotations']:
        if ann['ann-type'] != u'edge':
            continue
        left = by_uid[ann['ann-left']]['ann-span']
        right = by_uid[ann['ann-right']]['ann-span']
        label = ann['edge-type'] + ':' + ann['type']
        if fine_grained:
            label += ':' + ann['subtype']
        found.append((left[0], left[1], right[0], right[1], label))
    return found

# Store each document's mentions and relations on it, and collect their type
# labels into vocabularies. The label is the last field of both tuple kinds.
mention_vocab = Vocab(min_count=0)
relation_vocab = Vocab(min_count=0)
for doc in data.values():
    doc['mentions'] = compute_mentions(doc)
    doc['relations'] = compute_relations(doc)
    mention_vocab.add([mention[-1] for mention in doc['mentions']])
    relation_vocab.add([relation[-1] for relation in doc['relations']])

[(23, 28, u'value:TIME'), (62, 67, u'value:TIME'), (156, 158, u'value:Numeric'), (181, 183, u'value:Numeric'), (230, 231, u'entity:GPE'), (240, 241, u'entity:GPE'), (250, 251, u'entity:GPE'), (291, 292, u'entity:ORG'), (381, 382, u'entity:ORG'), (293, 294, u'entity:GPE'), (346, 347, u'entity:GPE'), (361, 362, u'entity:PER'), (379, 380, u'entity:GPE'), (75, 76, u'entity:ORG'), (92, 93, u'entity:ORG'), (55, 56, u'entity:PER'), (90, 91, u'entity:PER'), (374, 375, u'entity:PER'), (131, 132, u'entity:PER'), (347, 348, u'entity:PER'), (136, 137, u'entity:GPE'), (206, 207, u'entity:GPE'), (212, 213, u'entity:GPE'), (225, 226, u'entity:GPE'), (270, 271, u'entity:GPE'), (306, 307, u'entity:GPE'), (335, 336, u'entity:GPE'), (354, 355, u'entity:GPE'), (150, 151, u'entity:GPE'), (151, 152, u'entity:PER'), (153, 154, u'entity:PER'), (332, 333, u'entity:PER'), (163, 164, u'entity:PER'), (173, 174, u'entity:GPE'), (208, 209, u'entity:GPE')]


In [12]:
print mention_vocab.v, mention_vocab.n
print mention_vocab.vocabset
print '-'*50
print relation_vocab.v, relation_vocab.n
print relation_vocab.vocabset

23 53713
set([u'entity:WEA', u'value:Job-Title', u'entity:FAC', u'value:Crime', u'entity:PER', u'entity:VEH', u'value:Sentence', u'entity:LOC', u'value:Numeric', u'event-anchor:Conflict', u'event-anchor:Contact', '<UNK>', u'event-anchor:Life', u'value:TIME', '<PAD>', u'value:Contact-Info', u'event-anchor:Personnel', u'event-anchor:Movement', u'event-anchor:Business', u'entity:ORG', u'entity:GPE', u'event-anchor:Transaction', u'event-anchor:Justice'])
--------------------------------------------------
71 235619
set([u'relation:--PHYS->', u'event-argument:<-ARG:Beneficiary--', u'event-argument:--ARG:Recipient->', u'event-argument:<-ARG:Person--', u'relation:--PER-SOC->', u'event-argument:--ARG:Sentence->', u'event-argument:--ARG:Entity->', u'relation:<-PHYS--', u'event-argument:--ARG:Destination->', u'event-argument:--ARG:Artifact->', u'event-argument:--ARG:Position->', u'event-argument:--ARG:Person->', u'event-argument:--ARG:Plaintiff->', u'event-argument:<-ARG:Entity--', u'coreference:

In [13]:
# create datasets
test = .1
valid = .1
dataset = [(doc['tokens'], doc['boundary_labels'], doc['mentions'], doc['relations']) for doc in data.values()]
npr.shuffle(dataset)
test = 1-test # eg, .1 -> .9
valid = test-valid # eg, .1 -> .9
valid_split = int(len(dataset)*valid)
test_split = int(len(dataset)*test)
dataset_train, dataset_valid, dataset_test = (dataset[:valid_split], 
                                              dataset[valid_split:test_split], 
                                              dataset[test_split:])

x_train = [d[0] for d in dataset_train]
b_train = [d[1] for d in dataset_train]
m_train = [d[2] for d in dataset_train]
r_train = [d[3] for d in dataset_train]

x_valid = [d[0] for d in dataset_valid]
b_valid = [d[1] for d in dataset_valid]
m_valid = [d[2] for d in dataset_valid]
r_valid = [d[3] for d in dataset_valid]

x_test = [d[0] for d in dataset_test]
b_test = [d[1] for d in dataset_test]
m_test = [d[2] for d in dataset_test]
r_test = [d[3] for d in dataset_test]

print '{} train, {} validation, and {} test documents'.format(len(x_train), len(x_valid), len(x_test))

428 train, 53 validation, and 54 test documents


In [31]:
print (1,2)+tuple('a')
print (0,)

(1, 2, 'a')
(0,)


In [35]:
batch_size = 64
# convert dataset to idxs
# before we do conversions, we need to drop unfrequent words from the vocab and reindex it
print "Setting up...",
token_vocab.drop_infrequent()
boundary_vocab.drop_infrequent()
mention_vocab.drop_infrequent()
relation_vocab.drop_infrequent()

ix_train = convert_sequences(x_train, token_vocab.idx)
ix_valid = convert_sequences(x_valid, token_vocab.idx)
ix_test = convert_sequences(x_test, token_vocab.idx)
ib_train = convert_sequences(b_train, boundary_vocab.idx)
ib_valid = convert_sequences(b_valid, boundary_vocab.idx)
ib_test = convert_sequences(b_test, boundary_vocab.idx)
convert_mention = lambda x:x[:-1]+(mention_vocab.idx(x[-1]),)
im_train = convert_sequences(m_train, convert_mention)
im_valid = convert_sequences(m_valid, convert_mention)
im_test = convert_sequences(m_test, convert_mention)
convert_relation = lambda x:x[:-1]+(relation_vocab.idx(x[-1]),)
ir_train = convert_sequences(r_train, convert_relation)
ir_valid = convert_sequences(r_valid, convert_relation)
ir_test = convert_sequences(r_test, convert_relation)

# data
train_iter = SequenceIterator(zip(ix_train, ib_train, im_train, ir_train), batch_size, repeat=True)
valid_iter = SequenceIterator(zip(ix_valid, ib_valid, im_valid, ir_valid), batch_size, repeat=True)
print "Done"

Setting up... Done


# Building Model

In [8]:
from infonet.tagger import Tagger, extract_all_mentions
class Extractor(ch.Chain):
    """Scores mention and relation candidates built on top of tagger output.

    Mention spans are decoded from the tagger's argmax predictions by
    ``extract_all_mentions``. Each mention is represented by the sum of the
    tagger features over its ``[start, end)`` span. Each candidate relation is
    represented by the concatenation of its left and right mention vectors.

    Args:
        feature_size: width of each tagger feature vector.
        n_mention_class: number of outputs of the mention scorer ``f_m``.
        n_relation_class: number of outputs of the relation scorer ``f_r``.
        in_tags, out_tags: tag ids forwarded to ``extract_all_mentions``.
        max_rel_dist: a mention pair becomes a relation candidate only if
            their start positions differ by less than this value.
        verbose: if True, write a per-document progress line while building
            the graph. Off by default so forward passes don't flood stdout.
    """
    def __init__(self,
                 feature_size,
                 n_mention_class,
                 n_relation_class,
                 in_tags=(1,2),
                 out_tags=(0,),
                 max_rel_dist=1000,
                 verbose=False):
        super(Extractor, self).__init__(
            f_m=ch.links.Linear(feature_size, n_mention_class),
            f_r=ch.links.Linear(2*feature_size, n_relation_class)
        )
        self.in_tags = in_tags
        self.out_tags = out_tags
        self.max_rel_dist = max_rel_dist
        self.verbose = verbose

    def _debug(self, msg):
        """Write `msg` to stdout, but only when ``verbose`` is enabled."""
        if self.verbose:
            sys.stdout.write(msg)
            sys.stdout.flush()

    @staticmethod
    def _stack_rows(rows):
        """vstack a non-empty list of Variables; return an empty list unchanged.

        __call__ relies on empty documents keeping a plain (non-Variable)
        empty list so they can be skipped.
        """
        return ch.functions.vstack(rows) if rows else rows

    def _extract_graph(self, tagger_preds, tagger_features):
        """Build per-document mention and candidate-relation representations.

        Args:
            tagger_preds: time-major sequence of per-step predicted tags.
            tagger_features: time-major sequence of per-step features.

        Returns:
            (all_mentions, all_left_mentions, all_right_mentions,
            all_relation_idxs), with one entry per document. Each of the first
            three is a stacked Variable (one row per mention or candidate
            pair), or an empty list when the document has none.
            all_relation_idxs holds [j, i] mention-index pairs with j < i,
            aligned with the rows of the left/right stacks.
        """
        # convert from time-major to batch-major
        tagger_preds = ch.functions.transpose_sequence(tagger_preds)
        tagger_features = ch.functions.transpose_sequence(tagger_features)

        # extract the mention boundaries for each doc
        all_boundaries = extract_all_mentions(tagger_preds,
                                              in_tags=self.in_tags,
                                              out_tags=self.out_tags)
        all_mentions = []
        all_relation_idxs = []
        all_left_mentions = []
        all_right_mentions = []
        for s, (boundaries, _, features) in enumerate(zip(all_boundaries, tagger_preds, tagger_features)):
            mentions = []
            relation_idxs = []
            left_mentions = []
            right_mentions = []
            for i, b in enumerate(boundaries):
                # a mention vector is the sum of the features over its span
                mention = ch.functions.sum(features[b[0]:b[1]], axis=0)
                mentions.append(mention)
                # pair it with every earlier mention that starts close enough (M choose 2)
                for j in range(i):
                    if abs(boundaries[j][0] - b[0]) < self.max_rel_dist:
                        relation_idxs.append([j, i])
                        left_mentions.append(mentions[j])
                        right_mentions.append(mention)
            self._debug('Doc {}: {} mentions, {} relation candidates\n'.format(
                s, len(mentions), len(relation_idxs)))
            all_mentions.append(self._stack_rows(mentions))
            all_left_mentions.append(self._stack_rows(left_mentions))
            all_right_mentions.append(self._stack_rows(right_mentions))
            all_relation_idxs.append(relation_idxs)
        return all_mentions, all_left_mentions, all_right_mentions, all_relation_idxs

    def __call__(self, tagger_logits, tagger_features):
        """Score mentions and candidate relations decoded from tagger output.

        Args:
            tagger_logits: time-major sequence of per-step tag logits. The
                predicted tag at each step is their argmax over axis 1.
            tagger_features: time-major sequence of per-step features.

        Returns:
            (m_logits, r_logits, rel_idxs), with one entry per document.
            m_logits and r_logits are the outputs of f_m and f_r, or [] for a
            document with no mentions or no candidate relations. rel_idxs are
            the [j, i] mention-index pairs matching the rows of r_logits.
        """
        tagger_preds = [ch.functions.argmax(logit, axis=1) for logit in tagger_logits]
        mentions, l_mentions, r_mentions, rel_idxs = self._extract_graph(
            tagger_preds,
            tagger_features)
        # concat left and right mentions into one relation vector;
        # documents without candidate pairs keep their empty list
        relations = [ch.functions.concat((left, right), axis=1)
                     if isinstance(left, ch.Variable) else left
                     for left, right in zip(l_mentions, r_mentions)]
        # score mentions and relations
        m_logits = [self.f_m(m) if isinstance(m, ch.Variable) else []
                    for m in mentions]
        r_logits = [self.f_r(r) if isinstance(r, ch.Variable) else []
                    for r in relations]
        return m_logits, r_logits, rel_idxs

# Training Model