In [1]:
import pickle 

def atisfold(fold):
    assert fold in range(5)
    f = PREFIX + 'atis.fold'+str(fold)+'.pkl'
    train_set, valid_set, test_set, dicts = pickle.load(open(f, 'rb'), encoding='bytes')
    return train_set, valid_set, test_set, dicts


In [2]:
PREFIX = 'dataset/'
w2ne, w2la = {}, {}
train, _, test, dic = atisfold(1)
w2idx, ne2idx, labels2idx = dic[b'words2idx'], dic[b'tables2idx'], dic[b'labels2idx']

In [3]:
print(dic.keys())

dict_keys([b'words2idx', b'tables2idx', b'labels2idx'])


In [4]:
def get_entities(labels):
    idx = 0
    last_begin = -1
    entity = ""
    entities = []
    while idx < len(labels):
        if labels[idx].startswith('B'):
            last_begin = idx;
            entity = labels[idx][2:]
        elif labels[idx].startswith('O'):
            if last_begin > 0:
                entities.append((last_begin, idx, entity))
                last_begin = -1
        idx += 1
    if last_begin > 0:
        entities.append((last_begin, idx, entity))

    return entities

from mitie import *
idx2w  = dict((v,k) for k,v in w2idx.items())
idx2ne = dict((v,k) for k,v in ne2idx.items())
idx2la = dict((v,k) for k,v in labels2idx.items())

test_x,  test_ne,  test_label  = test
train_x, train_ne, train_label = train
trainer = ner_trainer("../MITIE-models/english/total_word_feature_extractor.dat")

output = 0
for sentence_a, label_a in zip(train_x, train_label):
    instance = [idx2w[word].decode('utf8') for word in sentence_a]
    labels = [idx2la[label].decode('utf8') for label in label_a]
    sample = ner_training_instance(instance)
    print(instance)
    print(labels)
    print()
    for entity in get_entities(labels):
        sample.add_entity(xrange(entity[0], entity[1]), entity[2])
    trainer.add(sample)
    output += 1
    if output > 5: break

['what', 'aircraft', 'is', 'used', 'on', 'delta', 'flight', 'DIGITDIGITDIGITDIGIT', 'from', 'kansas', 'city', 'to', 'salt', 'lake', 'city']
['O', 'O', 'O', 'O', 'O', 'B-airline_name', 'O', 'B-flight_number', 'O', 'B-fromloc.city_name', 'I-fromloc.city_name', 'O', 'B-toloc.city_name', 'I-toloc.city_name', 'I-toloc.city_name']

['i', 'want', 'to', 'go', 'from', 'boston', 'to', 'atlanta', 'on', 'monday']
['O', 'O', 'O', 'O', 'O', 'B-fromloc.city_name', 'O', 'B-toloc.city_name', 'O', 'B-depart_date.day_name']

['i', 'need', 'a', 'flight', 'from', 'atlanta', 'to', 'philadelphia', 'and', 'i', "'m", 'looking', 'for', 'the', 'cheapest', 'fare']
['O', 'O', 'O', 'O', 'O', 'B-fromloc.city_name', 'O', 'B-toloc.city_name', 'O', 'O', 'O', 'O', 'O', 'O', 'B-cost_relative', 'O']

['i', 'need', 'a', 'flight', 'from', 'toronto', 'to', 'montreal', 'reaching', 'montreal', 'early', 'on', 'friday']
['O', 'O', 'O', 'O', 'O', 'B-fromloc.city_name', 'O', 'B-toloc.city_name', 'O', 'B-toloc.city_name', 'B-arrive

In [5]:
trainer.num_threads = 4

ner = trainer.train()


In [6]:
print ("tags:", ner.get_possible_ner_tags())

tags: ['airline_name', 'flight_number', 'fromloc.city_name', 'toloc.city_name', 'depart_date.day_name', 'cost_relative', 'arrive_time.period_mod', 'arrive_date.day_name', 'depart_time.period_of_day', 'fromloc.airport_name']


In [7]:
output = 0
for sentence_a, label_a in zip(test_x, test_label):
    tokens = [idx2w[word].decode('utf8') for word in sentence_a]
    labels = [idx2la[label].decode('utf8') for label in label_a]
    entities = ner.extract_entities(tokens)
    print ("\nSentence: ", tokens)
    print ("\nTest Label:", labels)
    print ("\nEntities found:", entities)
    print ("\nTest results:", get_entities(labels))
    print ("\nNumber of entities detected:", len(entities))
    print ()
    for e in entities:
        range = e[0]
        tag = e[1]
        entity_text = " ".join(tokens[i] for i in range)
        print ("    " + tag + ": " + entity_text)
    output += 1
    if output > 2: break


Sentence:  ['i', 'would', 'like', 'to', 'find', 'a', 'flight', 'from', 'charlotte', 'to', 'las', 'vegas', 'that', 'makes', 'a', 'stop', 'in', 'st.', 'louis']

Test Label: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-fromloc.city_name', 'O', 'B-toloc.city_name', 'I-toloc.city_name', 'O', 'O', 'O', 'O', 'O', 'B-stoploc.city_name', 'I-stoploc.city_name']

Entities found: [(range(2, 3), 'fromloc.city_name', 0.31079096411241164), (range(5, 6), 'depart_time.period_of_day', 0.14010175800500133), (range(8, 9), 'fromloc.city_name', 0.7084902255174335), (range(10, 12), 'toloc.city_name', 0.6191374784467214), (range(18, 19), 'toloc.city_name', 0.1356279520145114)]

Test results: [(8, 9, 'fromloc.city_name'), (10, 12, 'toloc.city_name'), (17, 19, 'stoploc.city_name')]

Number of entities detected: 5

    fromloc.city_name: like
    depart_time.period_of_day: a
    fromloc.city_name: charlotte
    toloc.city_name: las vegas
    toloc.city_name: louis

Sentence:  ['on', 'april', 'first', 'i', 'need'