In [263]:
import pandas as pd
import numpy as np
from utils.data_helper import get_markable_dataframe, get_embedding_variables
from model_builders.coreference_classifier import CoreferenceClassifierModelBuilder
from functools import reduce
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model
from utils.clusterers import BestFirstClusterer, get_anaphora_scores_by_antecedent, ClosestFirstClusterer
from utils.scorers import MUCScorer, B3Scorer
from utils.data_structures import UFDS

In [2]:
# Paths to the two embedding helper files consumed by get_embedding_variables.
embedding_indexes_file_path = 'helper_files/embedding/embedding_indexes.txt'
indexed_embedding_file_path = 'helper_files/embedding/indexed_embedding.txt'

# The helper returns four values; unpack them in the order it yields them.
(
    word_vector,
    embedding_matrix,
    idx_by_word,
    word_by_idx,
) = get_embedding_variables(
    embedding_indexes_file_path,
    indexed_embedding_file_path,
)

In [229]:
markables = get_markable_dataframe(
    "data/testing/markables_with_predicted_singleton.csv",
    word_vector,
    idx_by_word,
)

# Ids of markables whose `is_singleton` entry is positive at position 1
# (presumably the "is a singleton" slot of a two-value vector -- TODO confirm).
singletons = set(
    markables.loc[
        markables['is_singleton'].map(lambda flags: bool(flags[1] > 0)),
        'id',
    ]
)
markables.head()

Unnamed: 0,id,text,is_pronoun,entity_type,is_proper_name,is_first_person,num_words,previous_words,next_words,is_singleton
0,1916,"[1258, 1259, 1955, 1389]",0,"[0, 0, 1, 0, 0, 0, 0, 0, 1, 0]",1,0,4,[],"[996, 377, 1156, 212, 26, 1258, 1956, 1183, 14...","[0.0, 1.0]"
1,1917,[212],1,"[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]",0,0,1,"[1258, 1259, 1955, 1389, 996, 377, 1156]","[26, 1258, 1956, 1183, 1464, 24, 1156, 62, 422...","[1.0, 0.0]"
2,1918,"[1258, 1956, 1183]",0,"[0, 0, 1, 0, 0, 0, 0, 1, 0, 0]",1,0,3,"[1258, 1259, 1955, 1389, 996, 377, 1156, 212, 26]","[1464, 24, 1156, 62, 422, 1218, 24, 1409, 1156...","[0.0, 1.0]"
3,1919,"[1464, 24, 1156]",0,"[1, 0, 0, 0, 0, 0, 0, 0, 1, 0]",0,0,3,"[1955, 1389, 996, 377, 1156, 212, 26, 1258, 19...","[62, 422, 1218, 24, 1409, 1156, 874, 342, 212,...","[0.0, 1.0]"
4,1920,[422],0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]",0,0,1,"[1156, 212, 26, 1258, 1956, 1183, 1464, 24, 11...","[1218, 24, 1409, 1156, 874, 342, 212, 404, 121...","[0.0, 1.0]"


In [261]:
max_sentence = 6

# Read at most the first 250000 pair rows and keep only those whose
# sentence distance is strictly below `max_sentence`.
pairs = pd.read_csv("data/testing/mention_pairs.csv", nrows=250000).loc[
    lambda frame: frame['sentence_distance'] < max_sentence
]

# Two-class categorical encoding of the gold `is_coreference` column.
label = np.vstack(
    to_categorical(pairs.is_coreference, num_classes=2)
)

# Gold chains: the gold labels are clustered the same way predictions are later on.
gold_clusterer = ClosestFirstClusterer()
label_chains = gold_clusterer.get_chains(
    get_anaphora_scores_by_antecedent(pairs.m1_id, pairs.m2_id, label)
)

pairs.head()

Unnamed: 0,m1_id,m2_id,is_exact_match,is_words_match,is_substring,is_abbreviation,is_appositive,is_nearest_candidate,sentence_distance,word_distance,markable_distance,is_coreference
0,2245,2246,0,0,0,0,0,1,0,3,1,0
1,2245,2247,0,0,0,0,0,0,0,5,2,0
2,2245,2248,0,0,0,0,0,0,0,7,3,0
3,2245,2249,0,0,0,0,0,0,0,9,4,0
4,2245,2250,0,0,0,0,0,0,0,12,5,0


In [231]:
# Sequence lengths used by get_data when padding the markable text, the
# preceding-words context and the following-words context. All three currently
# share one value; assign them separately again if they need to differ.
max_text_length = max_prev_words_length = max_next_words_length = 10

def _flatten_feature_row(row):
    """Flatten one row of feature values into a single flat list.

    Values that are exactly of type ``list`` are spliced in element by element;
    every other value is appended as-is.
    """
    flat = []
    for value in row:
        if type(value) is list:
            flat.extend(value)
        else:
            flat.append(value)
    return flat


def get_data(markable_ids):
    """Build the per-markable model inputs for a sequence of markable ids.

    Rows are taken from the module-level ``markables`` frame, one row per entry
    of ``markable_ids`` and in the same order (ids may repeat). If an id occurs
    more than once in ``markables['id']``, its first occurrence is used.

    Args:
        markable_ids: iterable of values to look up in ``markables['id']``.

    Returns:
        A 5-tuple ``(data_text, data_previous_words, data_next_words,
        data_syntactic, is_singleton)``:
          * ``data_text``: the ``text`` sequences, padded at the end to
            ``max_text_length``.
          * ``data_previous_words``: the last ``max_prev_words_length`` entries
            of ``previous_words``, padded at the front.
          * ``data_next_words``: the first ``max_next_words_length`` entries of
            ``next_words``, padded at the end.
          * ``data_syntactic``: array whose rows are the flattened
            ``is_pronoun``, ``entity_type``, ``is_proper_name`` and
            ``is_first_person`` values.
          * ``is_singleton``: the ``is_singleton`` entries stacked row-wise.

    Raises:
        KeyError: if an id is not present in ``markables['id']``.
    """
    # Build the id -> row-label lookup in one pass over the frame instead of
    # scanning the whole `id` column once per requested id; setdefault keeps
    # the first occurrence of a repeated id.
    first_label_by_id = {}
    for row_label, markable_id in zip(markables.index, markables['id']):
        first_label_by_id.setdefault(markable_id, row_label)

    # A plain comprehension avoids the quadratic cost of growing a list through
    # repeated `a + [b]` concatenation.
    indices = [first_label_by_id[markable_id] for markable_id in markable_ids]
    data = markables.loc[indices]

    data_text = pad_sequences(data.text, maxlen=max_text_length, padding='post')
    # The context nearest to the markable is kept: the tail of the preceding
    # words and the head of the following words.
    data_previous_words = pad_sequences(
        data.previous_words.map(lambda seq: seq[(-1 * max_prev_words_length):]),
        maxlen=max_prev_words_length,
        padding='pre',
    )
    data_next_words = pad_sequences(
        data.next_words.map(lambda seq: seq[:max_next_words_length]),
        maxlen=max_next_words_length,
        padding='post',
    )

    syntactic_columns = ['is_pronoun', 'entity_type', 'is_proper_name', 'is_first_person']
    data_syntactic = np.array(
        [_flatten_feature_row(row) for row in data[syntactic_columns].values]
    )
    is_singleton = np.vstack(data.is_singleton)

    return data_text, data_previous_words, data_next_words, data_syntactic, is_singleton

def get_pair_data(markable_ids_1, markable_ids_2):
    """Return the get_data outputs of both pair members, interleaved.

    ``get_data`` is evaluated for ``markable_ids_1`` and then for
    ``markable_ids_2``. The result is a 10-tuple in which every feature of the
    first member is immediately followed by the same feature of the second:
    ``(text_1, text_2, prev_1, prev_2, next_1, next_2, syntactic_1,
    syntactic_2, is_singleton_1, is_singleton_2)``.
    """
    first_member = get_data(markable_ids_1)
    second_member = get_data(markable_ids_2)

    return tuple(
        feature
        for feature_pair in zip(first_member, second_member)
        for feature in feature_pair
    )

def get_relation_data(mention_pairs):
    """Select the pairwise relation feature columns from a mention-pair frame.

    Args:
        mention_pairs: DataFrame containing at least the nine relation columns
            listed below.

    Returns:
        A DataFrame with exactly those nine columns, in the listed order.
    """
    relation_columns = [
        'is_exact_match',
        'is_words_match',
        'is_substring',
        'is_abbreviation',
        'is_appositive',
        'is_nearest_candidate',
        'sentence_distance',
        'word_distance',
        'markable_distance',
    ]
    return mention_pairs[relation_columns]

# Compute Baseline Score

In [265]:
baseline_result_file_path = 'baseline/test_result.txt'

baseline_ufds = UFDS()

# Register the ids of every evaluated pair with the union-find structure before
# any joins are applied.
for m1, m2 in zip(pairs.m1_id, pairs.m2_id):
    baseline_ufds.init_id(m1, m2)

# Each non-blank line of the result file starts with two ', '-separated integer
# ids that are joined into one set. The context manager closes the file handle
# once reading is done; blank lines (e.g. a trailing newline) are skipped
# instead of crashing int().
with open(baseline_result_file_path, 'r') as baseline_result_file:
    for line in baseline_result_file:
        if not line.strip():
            continue
        fields = line.split(', ')
        baseline_ufds.join(int(fields[0]), int(fields[1]))

baseline_chains = baseline_ufds.get_chain_list()

# Score the baseline chains against the gold chains.
print('MUC: ', MUCScorer().get_scores(baseline_chains, label_chains))
print('B3: ', B3Scorer().get_scores(baseline_chains, label_chains))

MUC:  (0.26732673267326734, 0.7297297297297297, 0.391304347826087)
B3:  (0.14114949258391882, 0.6940000000000001, 0.2345873372925423)


# Test Models

In [232]:
# Features of both members of every pair, interleaved member 1 / member 2.
(
    text_1, text_2,
    prev_1, prev_2,
    next_1, next_2,
    syntactic_1, syntactic_2,
    is_singleton_1, is_singleton_2,
) = get_pair_data(pairs.m1_id, pairs.m2_id)

# Pair-level relation features selected from the pairs frame.
relation = get_relation_data(pairs)

## Soon

### Words

In [246]:
# Saved classifier evaluated in this section.
words_model_path = 'models/coreference_classifiers/words_gilang.model'
words_syntactic_model_1 = load_model(words_model_path)

In [247]:
# The model is fed only the padded text of both pair members.
words_model_inputs = [text_1, text_2]
syntactic_1_pred = words_syntactic_model_1.predict(words_model_inputs, verbose=1)



In [248]:
def _scores_by_antecedent(pair_scores, *extra_args):
    """Call get_anaphora_scores_by_antecedent for the current pair ids."""
    return get_anaphora_scores_by_antecedent(
        pairs.m1_id, pairs.m2_id, pair_scores, *extra_args
    )


# Model scores, without and with the set of predicted singleton ids.
predz = _scores_by_antecedent(syntactic_1_pred)
predz2 = _scores_by_antecedent(syntactic_1_pred, singletons)
# Same structure built from the gold labels.
labz = _scores_by_antecedent(label)

In [262]:
# Both variants are clustered with the same score threshold.
chain_threshold = 0.4

pred_chains = ClosestFirstClusterer().get_chains(predz, threshold=chain_threshold)
pred_chains2 = ClosestFirstClusterer().get_chains(predz2, threshold=chain_threshold)

In [258]:
# Score both predicted chain sets against the gold chains. The gold chains are
# held in `label_chains` (built in the pairs-loading cell); the previous name
# `lab_chains` is not defined anywhere in the notebook and only resolved through
# leftover kernel state.
# First line of each group: without singleton information; second line: with it.
print(MUCScorer().get_scores(pred_chains, label_chains))
print(MUCScorer().get_scores(pred_chains2, label_chains))
print()
print(B3Scorer().get_scores(pred_chains, label_chains))
print(B3Scorer().get_scores(pred_chains2, label_chains))

(0.46808510638297873, 0.5945945945945946, 0.5238095238095238)
(0.625, 0.5405405405405406, 0.5797101449275363)

(0.39356118791602657, 0.5619999999999999, 0.46293505932608897)
(0.5800405268490375, 0.5013333333333334, 0.5378226004919506)


In [266]:
# Display the gold chains derived from the `is_coreference` labels.
# NOTE(review): this prints every chain; consider showing a slice to keep the output short.
label_chains

[[2245],
 [2246],
 [2247],
 [2248],
 [2249],
 [2250],
 [2251],
 [2252],
 [2253],
 [2254],
 [2256],
 [2257],
 [2259],
 [2260],
 [2261],
 [2262],
 [2263],
 [2264],
 [2265],
 [2267],
 [2268],
 [2269],
 [2271],
 [2272],
 [2273],
 [2274],
 [2275],
 [2276],
 [2277],
 [2278],
 [2279],
 [2280],
 [2281],
 [2282],
 [2283],
 [2284],
 [2285],
 [2287],
 [2288],
 [2289],
 [2290],
 [2291],
 [2292],
 [2293],
 [2294],
 [2295],
 [2296],
 [2297],
 [2298],
 [2299],
 [2300],
 [2301],
 [2302],
 [2255, 2258, 2266, 2270, 2286, 2303],
 [2304],
 [2305],
 [2306],
 [2307],
 [2308],
 [2309],
 [2310],
 [2312],
 [2313],
 [2314],
 [2315],
 [2317],
 [2318],
 [2319],
 [2320],
 [2321],
 [2322],
 [2311, 2316, 2323],
 [2324],
 [2325],
 [2326],
 [2327],
 [2328],
 [2329],
 [2331],
 [2332],
 [2334],
 [2335],
 [2336],
 [2338],
 [2339],
 [2340],
 [2341],
 [2342],
 [2344],
 [2345],
 [2346],
 [2333, 2343, 2347],
 [2348],
 [2349],
 [2350],
 [2351],
 [2352],
 [2353],
 [2354],
 [2355],
 [2356],
 [2357],
 [2358],
 [2359],
 [2360],
 