In [1]:
import pandas as pd
import numpy as np
from utils.data_helper import get_markable_dataframe, get_embedding_variables
from model_builders.coreference_classifier import CoreferenceClassifierModelBuilder
from functools import reduce
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model
from utils.clusterers import BestFirstClusterer, get_anaphora_scores_by_antecedent, ClosestFirstClusterer
from utils.scorers import MUCScorer, B3Scorer
from utils.data_structures import UFDS

In [2]:
# Locations of the embedding helper files consumed by get_embedding_variables.
embedding_indexes_file_path = 'helper_files/embedding/embedding_indexes.txt'
indexed_embedding_file_path = 'helper_files/embedding/indexed_embedding.txt'

# The helper returns four objects; word_vector and idx_by_word are passed to
# get_markable_dataframe later, word_by_idx is used to map indexes back to words.
(
    word_vector,
    embedding_matrix,
    idx_by_word,
    word_by_idx,
) = get_embedding_variables(
    embedding_indexes_file_path,
    indexed_embedding_file_path,
)

In [70]:
markables = get_markable_dataframe(
    "data/testing/markables_with_predicted_singleton.csv",
    word_vector,
    idx_by_word,
)
# Ids of the markables whose second `is_singleton` component is positive
# (presumably the "is a singleton" slot of a two-class vector -- confirm
# against the code that wrote the CSV).
singletons = set(
    markables[markables['is_singleton'].map(lambda scores: scores[1] > 0)]['id']
)
markables.head()

Unnamed: 0,id,text,is_pronoun,entity_type,is_proper_name,is_first_person,num_words,previous_words,next_words,is_singleton
0,1916,"[1258, 1259, 1955, 1389]",0,"[0, 1, 0, 0, 0, 1, 0, 0, 0, 0]",1,0,4,[],"[996, 377, 1156, 212, 26, 1258, 1956, 1183, 14...","[0.0, 1.0]"
1,1917,[212],1,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]",0,0,1,"[1258, 1259, 1955, 1389, 996, 377, 1156]","[26, 1258, 1956, 1183, 1464, 24, 1156, 62, 422...","[1.0, 0.0]"
2,1918,"[1258, 1956, 1183]",0,"[0, 0, 0, 0, 0, 1, 0, 1, 0, 0]",1,0,3,"[1258, 1259, 1955, 1389, 996, 377, 1156, 212, 26]","[1464, 24, 1156, 62, 422, 1218, 24, 1409, 1156...","[0.0, 1.0]"
3,1919,"[1464, 24, 1156]",0,"[1, 1, 0, 0, 0, 0, 0, 0, 0, 0]",0,0,3,"[1955, 1389, 996, 377, 1156, 212, 26, 1258, 19...","[62, 422, 1218, 24, 1409, 1156, 874, 342, 212,...","[0.0, 1.0]"
4,1920,[422],0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]",0,0,1,"[1156, 212, 26, 1258, 1956, 1183, 1464, 24, 11...","[1218, 24, 1409, 1156, 874, 342, 212, 404, 121...","[0.0, 1.0]"


In [4]:
max_sentence = 6

# Load the candidate mention pairs and keep only those whose two mentions are
# fewer than `max_sentence` sentences apart.
pairs = pd.read_csv("data/testing/mention_pairs.csv").loc[
    lambda frame: frame['sentence_distance'] < max_sentence
]

# Two-column categorical encoding of the gold `is_coreference` flag.
label = np.vstack(to_categorical(pairs.is_coreference, num_classes=2))
# Gold chains, built by running the gold labels through the same clusterer
# that is applied to model predictions further down.
label_chains = ClosestFirstClusterer().get_chains(
    get_anaphora_scores_by_antecedent(pairs.m1_id, pairs.m2_id, label)
)

pairs.head()

Unnamed: 0,m1_id,m2_id,is_exact_match,is_words_match,is_substring,is_abbreviation,is_appositive,is_nearest_candidate,sentence_distance,word_distance,markable_distance,is_coreference
0,1916,1917,0,0,0,0,0,1,0,3,1,1
1,1916,1918,0,0,0,0,0,0,0,5,2,0
2,1916,1919,0,0,0,0,0,0,0,8,3,0
3,1916,1920,0,0,0,0,0,0,0,12,4,0
4,1916,1921,0,0,0,0,0,0,0,13,5,0


In [5]:
# Sequence lengths used by get_data when padding/truncating word-index lists.
# Length that each markable's own `text` sequence is padded to.
max_text_length = 10
# Number of trailing `previous_words` entries kept (and padded to) per markable.
max_prev_words_length = 10
# Number of leading `next_words` entries kept (and padded to) per markable.
max_next_words_length = 10

def get_data(markable_ids):
    """Build the per-markable model inputs for a sequence of markable ids.

    Rows are looked up in the module-level ``markables`` dataframe through its
    ``id`` column and returned in the same order as ``markable_ids`` (a
    repeated id yields a repeated row). When an id occurs more than once in
    ``markables``, the first matching row is used.

    Args:
        markable_ids: iterable of markable ids to look up.

    Returns:
        A 5-tuple ``(data_text, data_previous_words, data_next_words,
        data_syntactic, is_singleton)``:

        * ``data_text``: ``text`` sequences post-padded to ``max_text_length``.
        * ``data_previous_words``: last ``max_prev_words_length`` entries of
          ``previous_words``, pre-padded to that length.
        * ``data_next_words``: first ``max_next_words_length`` entries of
          ``next_words``, post-padded to that length.
        * ``data_syntactic``: one flat row per markable concatenating
          ``is_pronoun``, ``entity_type``, ``is_proper_name`` and
          ``is_first_person`` (list-valued cells are spliced in).
        * ``is_singleton``: ``np.vstack`` of the ``is_singleton`` column.

    Raises:
        IndexError: if an id is not present in ``markables['id']``.
    """
    # Build the id -> row-index lookup in a single pass. The previous version
    # ran a full boolean-mask scan of the dataframe for every requested id and
    # grew the result list by repeated concatenation, both of which are
    # quadratic for the large id sequences produced by the mention pairs.
    # setdefault keeps the first row seen for each id.
    first_index_by_id = {}
    for row_index, markable_id in zip(markables.index, markables['id']):
        first_index_by_id.setdefault(markable_id, row_index)

    indices = []
    for markable_id in markable_ids:
        if markable_id not in first_index_by_id:
            raise IndexError('markable id %r not found in markables' % (markable_id,))
        indices.append(first_index_by_id[markable_id])
    data = markables.loc[indices]

    # NOTE(review): pad_sequences is called without `truncating`, so texts
    # longer than max_text_length are cut according to the Keras default
    # (documented as 'pre', i.e. the leading tokens are dropped) -- confirm
    # this matches what the model was trained with.
    data_text = pad_sequences(data.text, maxlen=max_text_length, padding='post')
    data_previous_words = pad_sequences(
        data.previous_words.map(lambda seq: seq[-max_prev_words_length:]),
        maxlen=max_prev_words_length,
        padding='pre',
    )
    data_next_words = pad_sequences(
        data.next_words.map(lambda seq: seq[:max_next_words_length]),
        maxlen=max_next_words_length,
        padding='post',
    )

    # Flatten each row of mixed scalar / list cells into a single flat list.
    syntactic_rows = []
    for row in data[['is_pronoun', 'entity_type', 'is_proper_name', 'is_first_person']].values:
        flat_row = []
        for value in row:
            if isinstance(value, list):
                flat_row.extend(value)
            else:
                flat_row.append(value)
        syntactic_rows.append(flat_row)
    data_syntactic = np.array(syntactic_rows)

    is_singleton = np.vstack(data.is_singleton)

    return data_text, data_previous_words, data_next_words, data_syntactic, is_singleton

def get_pair_data(markable_ids_1, markable_ids_2):
    """Return the get_data features of two id sequences, interleaved per feature.

    The result is the 10-tuple ``(text_1, text_2, prev_1, prev_2, next_1,
    next_2, syntactic_1, syntactic_2, is_singleton_1, is_singleton_2)``, where
    the ``_1`` entries come from ``markable_ids_1`` and the ``_2`` entries
    from ``markable_ids_2``.
    """
    first_features = get_data(markable_ids_1)
    second_features = get_data(markable_ids_2)

    # Pair up the matching features and lay them out side by side.
    return tuple(
        feature
        for feature_pair in zip(first_features, second_features)
        for feature in feature_pair
    )

def get_relation_data(mention_pairs):
    """Select the pairwise relation feature columns from a mention-pair frame.

    Returns the nine relation columns, in the fixed order listed below, as a
    new dataframe; every other column of ``mention_pairs`` is left out.
    """
    relation_columns = [
        'is_exact_match',
        'is_words_match',
        'is_substring',
        'is_abbreviation',
        'is_appositive',
        'is_nearest_candidate',
        'sentence_distance',
        'word_distance',
        'markable_distance',
    ]
    return mention_pairs[relation_columns]

# Compute Baseline Score

In [126]:
baseline_result_file_path = 'baseline/test_result.txt'

baseline_ufds = UFDS()

# Pass every markable id that appears in a mention pair to the union-find
# structure (presumably so unlinked markables still form their own chain --
# confirm in utils.data_structures).
for m1, m2 in zip(pairs.m1_id, pairs.m2_id):
    baseline_ufds.init_id(m1, m2)

# Each non-blank line starts with "<id_1>, <id_2>, ..."; the two ids are
# joined into the same chain. `with` guarantees the file handle is closed,
# and iterating the handle avoids loading the whole file via readlines().
with open(baseline_result_file_path, 'r') as baseline_result_file:
    for line in baseline_result_file:
        if not line.strip():
            # A blank (e.g. trailing) line has no ids to parse.
            continue
        print(line)
        fields = line.split(', ')
        baseline_ufds.join(int(fields[0]), int(fields[1]))

baseline_chains = baseline_ufds.get_chain_list()

print('MUC: ', MUCScorer().get_scores(baseline_chains, label_chains))
print('B3: ', B3Scorer().get_scores(baseline_chains, label_chains))

1916, 1930, PT\NNP Astra\NNP Otoparts\NNP Tbk\NNP, PT\NNP Astra\NNP Otoparts\NNP

1918, 1931, PT\NNP Exedy\NNP Indonesia\NNP, PT\NNP Exedy\NNP Indonesia\NNP

1927, 1953, Kami\PRP, Menkeu\NNP Sri\NNP Mulyani\NNP Indrawati\NNP

1933, 1934, Sekretaris\NNP Perusahaan\NNP Astra\NNP Otoparts,\NNP, Kartina\NNP Rahayu\NNP

1933, 1937, Sekretaris\NNP Perusahaan\NNP Astra\NNP Otoparts,\NNP, dia,\PRP

1933, 1940, Sekretaris\NNP Perusahaan\NNP Astra\NNP Otoparts,\NNP, Dia\PRP

1940, 1946, Dia\PRP, nya\PRP

1940, 1952, Dia\PRP, nya.\PRP

1940, 1953, Dia\PRP, Menkeu\NNP Sri\NNP Mulyani\NNP Indrawati\NNP

1940, 2255, Dia\PRP, Hartadi\NNP A\NNP Sarwono\NNP

1964, 2255, Dia\PRP, Hartadi\NNP A\NNP Sarwono\NNP

1964, 2266, Dia\PRP, Hartadi\NNP

1964, 2269, Dia\PRP, Jakarta.\NNP

2252, 2253, Deputi\NNP Gubernur\NNP, Bank\NNP Indonesia\NNP

2255, 2258, Hartadi\NNP A\NNP Sarwono\NNP, Ia\PRP

2255, 2266, Hartadi\NNP A\NNP Sarwono\NNP, Hartadi\NNP

2258, 2266, Ia\PRP, Hartadi\NNP

2258, 2269, Ia\PRP, Jakarta.

# Test Models

In [7]:
# Per-markable features for both sides of every mention pair, in the order
# returned by get_pair_data (first-mention feature, then second-mention feature).
(
    text_1, text_2,
    prev_1, prev_2,
    next_1, next_2,
    syntactic_1, syntactic_2,
    is_singleton_1, is_singleton_2,
) = get_pair_data(pairs.m1_id, pairs.m2_id)

# Pair-level relation features taken straight from the mention-pair frame.
relation = get_relation_data(pairs)

## Soon

### Words

In [31]:
# Load the saved Keras classifier from disk; its inputs are supplied in the
# predict call of the following cell.
words_syntactic_model_1 = load_model('models/coreference_classifiers/words_syntactic_soon_10.model')

In [32]:
# Run the classifier on every mention pair. The inputs are the two markable
# texts, the two syntactic feature rows and the pair relation features; the
# previous/next word context and singleton features are not passed here.
syntactic_1_pred = words_syntactic_model_1.predict([text_1, text_2, syntactic_1, syntactic_2, relation], verbose=1)



In [71]:
# Model scores handed to the helper without any singleton information.
predz = get_anaphora_scores_by_antecedent(pairs.m1_id, pairs.m2_id, syntactic_1_pred)
# Same scores, but with the predicted-singleton id set passed as a fourth
# argument (presumably used to drop pairs involving those markables --
# confirm in utils.clusterers).
predz2 = get_anaphora_scores_by_antecedent(pairs.m1_id, pairs.m2_id, syntactic_1_pred, singletons)
# Gold labels run through the same helper; this repeats the argument used to
# build label_chains in the mention-pair loading cell.
labz = get_anaphora_scores_by_antecedent(pairs.m1_id, pairs.m2_id, label)

In [121]:
# Cluster both score sets with a fresh ClosestFirstClusterer and the same 0.7
# threshold: pred_chains comes from predz, pred_chains2 from predz2.
pred_chains, pred_chains2 = (
    ClosestFirstClusterer().get_chains(scores, threshold=0.7)
    for scores in (predz, predz2)
)

In [122]:
# MUC scores against the gold chains: first line is pred_chains (no singleton
# set), second line is pred_chains2 (built with the predicted-singleton set).
print(MUCScorer().get_scores(pred_chains, label_chains))
print(MUCScorer().get_scores(pred_chains2, label_chains))
print()
# B3 scores, in the same pred_chains / pred_chains2 order.
print(B3Scorer().get_scores(pred_chains, label_chains))
print(B3Scorer().get_scores(pred_chains2, label_chains))

(0.3508771929824561, 0.5263157894736842, 0.42105263157894735)
(0.49411764705882355, 0.5526315789473685, 0.5217391304347826)

(0.21866613466912124, 0.4686777623670827, 0.2982028505500351)
(0.1807609087630725, 0.47515025427646795, 0.26189092853369134)


In [123]:
(0.47692307692307695, 0.40789473684210525, 0.43971631205673756)
(0.5714285714285714, 0.21052631578947367, 0.3076923076923077)

(0.5714285714285714, 0.21052631578947367, 0.3076923076923077)

In [124]:
def get_markable_text(idx):
    """Return the words of the markable whose ``id`` equals ``idx``.

    The first row of ``markables`` with a matching ``id`` is used, and each
    entry of its ``text`` sequence is mapped through ``word_by_idx``.
    """
    matching_rows = markables[markables['id'] == idx]
    word_indexes = matching_rows.text.values[0]
    return [word_by_idx[word_index] for word_index in word_indexes]

In [125]:
# Words of every markable in each predicted chain that has more than one member.
[
    [get_markable_text(markable_id) for markable_id in chain]
    for chain in pred_chains2
    if len(chain) > 1
]

[[['bi'], ['bi']],
 [['ia'], ['nya'], ['nya'], ['nya'], ['ia'], ['nya']],
 [['menteri', 'keuangan', 'sri', 'mulyan'],
  ['sri', 'mulyan'],
  ['mulyan'],
  ['mulyan'],
  ['nya'],
  ['direktur', 'aal'],
  ['santosa'],
  ['nya'],
  ['nya'],
  ['nya']],
 [['bank', 'mandiri'], ['bank', 'mandiri'], ['bank', 'mandiri']],
 [['adb'], ['adb'], ['adb']],
 [['miranda', 's', 'goeltom'],
  ['miranda'],
  ['miranda'],
  ['nya'],
  ['nya'],
  ['deputi', 'gubernur', 'senior', 'bi'],
  ['miranda', 's', 'goeltom'],
  ['miranda'],
  ['nya'],
  ['itung', 'bi'],
  ['nya'],
  ['nya'],
  ['dia'],
  ['ia'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['menteri', 'keuangan', 'sri', 'mulyan', 'indrawat'],
  ['nya'],
  ['menko', 'perekonomian', 'boediono'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['findo'],
  ['trje'],
  ['analis', 'findo', 'ronald', 'hertanto'],
  ['nya'],
  ['dia'],
  ['trje'],
  ['trje'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['nya'],
  ['nya'],

In [127]:
# Predicted chains (as markable ids) that contain more than one member.
[chain for chain in pred_chains2 if len(chain) > 1]

[[2254, 2268],
 [2258, 2274, 2278, 2290, 2303, 2308],
 [2330, 2337, 2367, 2375, 2383, 2393, 2394, 2399, 2405, 2413],
 [2510, 2530, 2546],
 [2718, 2725, 2760],
 [2426,
  2436,
  2450,
  2464,
  2466,
  2469,
  2470,
  2480,
  2488,
  2489,
  2517,
  2521,
  2528,
  2535,
  2545,
  2551,
  2552,
  2554,
  2565,
  2592,
  2603,
  2620,
  2630,
  2636,
  2640,
  2652,
  2655,
  2668,
  2669,
  2672,
  2677,
  2693,
  2694,
  2702,
  2711,
  2727,
  2735,
  2738,
  2744,
  2752,
  2759,
  2766,
  2778,
  2786,
  2797,
  2813,
  2837,
  2847,
  2865,
  2869,
  2870,
  2881,
  2882,
  2925,
  2943,
  2950,
  2951,
  2956,
  2976],
 [1925, 1932],
 [1917, 1923, 1934, 1937, 1940, 1946],
 [1952, 1953, 1964]]

In [97]:
# Display every chain predicted without the singleton set, single-member
# chains included -- the output is long; filter on len(chain) > 1 for a digest.
pred_chains

[[2245],
 [2246],
 [2247],
 [2248],
 [2249],
 [2250],
 [2251],
 [2252],
 [2253],
 [2255],
 [2256],
 [2257],
 [2259],
 [2260],
 [2261],
 [2262],
 [2263],
 [2264],
 [2265],
 [2267],
 [2269],
 [2271],
 [2272],
 [2273],
 [2275],
 [2276],
 [2277],
 [2279],
 [2280],
 [2281],
 [2282],
 [2283],
 [2284],
 [2285],
 [2266, 2270, 2286],
 [2287],
 [2254, 2268, 2288],
 [2289],
 [2291],
 [2292],
 [2293],
 [2294],
 [2295],
 [2296],
 [2297],
 [2298],
 [2299],
 [2300],
 [2301],
 [2302],
 [2304],
 [2305],
 [2306],
 [2307],
 [2258, 2274, 2278, 2290, 2303, 2308],
 [2309],
 [2310],
 [2312],
 [2313],
 [2314],
 [2315],
 [2311, 2316],
 [2317],
 [2318],
 [2319],
 [2320],
 [2321],
 [2322],
 [2323],
 [2324],
 [2325],
 [2326],
 [2327],
 [2328],
 [2329],
 [2331],
 [2332],
 [2333],
 [2334],
 [2335],
 [2336],
 [2338],
 [2340],
 [2341],
 [2342],
 [2343],
 [2344],
 [2345],
 [2346],
 [2347],
 [2348],
 [2349],
 [2350],
 [2351],
 [2352],
 [2353],
 [2354],
 [2355],
 [2356],
 [2357],
 [2358],
 [2359],
 [2360],
 [2361],
 [23