In [1]:
import pandas as pd
import numpy as np
from utils.data_helper import get_markable_dataframe, get_embedding_variables
from model_builders.coreference_classifier import CoreferenceClassifierModelBuilder
from functools import reduce
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model
from utils.clusterers import BestFirstClusterer, get_anaphora_scores_by_antecedent, ClosestFirstClusterer
from utils.scorers import MUCScorer, B3Scorer
from utils.data_structures import UFDS

In [2]:
# Files handed to get_embedding_variables to build the embedding lookups.
embedding_indexes_file_path = 'helper_files/embedding/embedding_indexes.txt'
indexed_embedding_file_path = 'helper_files/embedding/indexed_embedding.txt'

# The four values returned here are reused by the data-loading and display cells below.
(
    word_vector,
    embedding_matrix,
    idx_by_word,
    word_by_idx,
) = get_embedding_variables(
    embedding_indexes_file_path,
    indexed_embedding_file_path,
)

In [3]:
markables = get_markable_dataframe(
    "data/testing/markables_with_predicted_singleton.csv",
    word_vector,
    idx_by_word,
)

# Ids of the markables whose `is_singleton` value has a positive second entry.
singletons = set(
    markables.loc[
        markables['is_singleton'].map(lambda flags: bool(flags[1] > 0)),
        'id',
    ]
)

markables.head()

Unnamed: 0,id,text,is_pronoun,entity_type,is_proper_name,is_first_person,previous_words,next_words,is_singleton
0,1916,"[1263, 1264, 1968, 1395]",0,"[0, 0, 0, 0, 1, 0, 1, 0, 0, 0]",1,0,[],"[999, 379, 1161, 213, 27, 1263, 1969, 1188, 14...","[0.0, 1.0]"
1,1917,[213],1,"[0, 0, 0, 0, 0, 0, 1, 0, 0, 0]",0,0,"[1263, 1264, 1968, 1395, 999, 379, 1161]","[27, 1263, 1969, 1188, 1470, 25, 1161, 63, 424...","[1.0, 0.0]"
2,1918,"[1263, 1969, 1188]",0,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 1]",1,0,"[1263, 1264, 1968, 1395, 999, 379, 1161, 213, 27]","[1470, 25, 1161, 63, 424, 1223, 25, 1415, 1161...","[0.0, 1.0]"
3,1919,"[1470, 25, 1161]",0,"[0, 1, 0, 0, 0, 0, 1, 0, 0, 0]",0,0,"[1968, 1395, 999, 379, 1161, 213, 27, 1263, 19...","[63, 424, 1223, 25, 1415, 1161, 876, 344, 213,...","[0.0, 1.0]"
4,1920,[424],0,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]",0,0,"[1161, 213, 27, 1263, 1969, 1188, 1470, 25, 11...","[1223, 25, 1415, 1161, 876, 344, 213, 406, 122...","[0.0, 1.0]"


In [4]:
pairs = pd.read_csv("data/testing/mention_pairs.csv")

# Two-class one-hot encoding of the `is_coreference` column.
label = np.vstack(to_categorical(pairs['is_coreference'], num_classes=2))

# Chains built from the labels themselves; every scorer call below compares against them.
label_chains = ClosestFirstClusterer().get_chains(
    get_anaphora_scores_by_antecedent(pairs['m1_id'], pairs['m2_id'], label)
)

pairs.head()

Unnamed: 0,m1_id,m2_id,is_exact_match,is_words_match,is_substring,is_abbreviation,is_appositive,is_nearest_candidate,sentence_distance,word_distance,markable_distance,is_coreference
0,1916,1917,0,0,0,0,0,1,0,3,1,1
1,1916,1918,0,0,0,0,0,0,0,5,2,0
2,1916,1919,0,0,0,0,0,0,0,8,3,0
3,1916,1920,0,0,0,0,0,0,0,12,4,0
4,1916,1921,0,0,0,0,0,0,0,13,5,0


In [5]:
# Sequence lengths used by get_data when padding/truncating the markable text
# and its previous/next word context; all three share the same value.
max_text_length = max_prev_words_length = max_next_words_length = 10

def get_data(markable_ids):
    """Build the per-markable model inputs for a sequence of markable ids.

    Each id is resolved to the first row of the global ``markables`` frame
    whose ``id`` column equals it; ids may repeat and the output keeps the
    order of ``markable_ids``.

    Args:
        markable_ids: iterable of markable ids (e.g. ``pairs.m1_id``).

    Returns:
        A 5-tuple ``(data_text, data_previous_words, data_next_words,
        data_syntactic, is_singleton)``:

        * ``data_text``: ``text`` padded at the end to ``max_text_length``.
        * ``data_previous_words``: last ``max_prev_words_length`` entries of
          ``previous_words``, padded at the front.
        * ``data_next_words``: first ``max_next_words_length`` entries of
          ``next_words``, padded at the end.
        * ``data_syntactic``: one flat row per markable made of
          ``is_pronoun``, ``entity_type``, ``is_proper_name`` and
          ``is_first_person`` (list-valued cells are expanded in place).
        * ``is_singleton``: the ``is_singleton`` values stacked row-wise.

    Raises:
        IndexError: if an id is not present in ``markables['id']``.
    """
    # Build the id -> row-label lookup once instead of scanning the whole
    # frame for every requested id. setdefault keeps the FIRST row per id,
    # matching the previous `...tolist()[0]` behaviour.
    index_by_id = {}
    for row_index, markable_id in zip(markables.index, markables['id']):
        index_by_id.setdefault(markable_id, row_index)

    try:
        indices = [index_by_id[markable_id] for markable_id in markable_ids]
    except KeyError as missing_id:
        # Keep the exception type the old positional lookup raised.
        raise IndexError('markable id {} not found in markables'.format(missing_id)) from None

    data = markables.loc[indices]

    data_text = pad_sequences(data.text, maxlen=max_text_length, padding='post')
    # Keep only the words closest to the markable: the tail of the preceding
    # context and the head of the following context.
    data_previous_words = pad_sequences(
        data.previous_words.map(lambda seq: seq[-max_prev_words_length:]),
        maxlen=max_prev_words_length,
        padding='pre',
    )
    data_next_words = pad_sequences(
        data.next_words.map(lambda seq: seq[:max_next_words_length]),
        maxlen=max_next_words_length,
        padding='post',
    )

    syntactic_columns = data[['is_pronoun', 'entity_type', 'is_proper_name', 'is_first_person']]
    # Flatten every row: list cells contribute all of their items, scalar
    # cells contribute themselves.
    data_syntactic = np.array([
        [
            value
            for cell in row
            for value in (cell if isinstance(cell, list) else [cell])
        ]
        for row in syntactic_columns.values
    ])
    is_singleton = np.vstack(data.is_singleton)

    return data_text, data_previous_words, data_next_words, data_syntactic, is_singleton

def get_pair_data(markable_ids_1, markable_ids_2):
    """Return the get_data features of both pair members, interleaved.

    The result is a 10-tuple ordered as
    ``(text_1, text_2, prev_1, prev_2, next_1, next_2,
    syntactic_1, syntactic_2, is_singleton_1, is_singleton_2)``.
    """
    first_features = get_data(markable_ids_1)
    second_features = get_data(markable_ids_2)

    # Pair up the matching feature of each side, then flatten the pairs.
    return tuple(
        feature
        for feature_pair in zip(first_features, second_features)
        for feature in feature_pair
    )

def get_relation_data(mention_pairs):
    """Select the pairwise relation feature columns from a mention-pair frame.

    Returns a frame containing, in this order, the match flags
    (exact/words/substring/abbreviation/appositive/nearest-candidate) followed
    by the sentence, word and markable distances.
    """
    relation_columns = [
        'is_exact_match',
        'is_words_match',
        'is_substring',
        'is_abbreviation',
        'is_appositive',
        'is_nearest_candidate',
        'sentence_distance',
        'word_distance',
        'markable_distance',
    ]
    return mention_pairs[relation_columns]

# Compute Baseline Score

In [237]:
baseline_result_file_path = 'baseline/test_result.txt'

baseline_ufds = UFDS()

# Register the ids of every mention pair with the union-find structure before
# any joins are applied.
for m1, m2 in zip(pairs.m1_id, pairs.m2_id):
    baseline_ufds.init_id(m1, m2)

# Each non-empty line of the baseline result file is "<id>, <id>": a pair of
# markables the baseline linked. The context manager guarantees the file is
# closed (the previous bare open() leaked the handle), and iterating the file
# avoids loading it all at once.
with open(baseline_result_file_path, 'r') as baseline_result_file:
    for line in baseline_result_file:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        line = line.split(', ')
        baseline_ufds.join(int(line[0]), int(line[1]))

baseline_chains = baseline_ufds.get_chain_list()

print('MUC: ', MUCScorer().get_scores(baseline_chains, label_chains))
print('B3: ', B3Scorer().get_scores(baseline_chains, label_chains))

MUC:  (0.5544554455445545, 0.7272727272727273, 0.6292134831460674)
B3:  (0.3124361294443262, 0.6732829670329671, 0.4268110965737344)


# Test Models

In [7]:
# Model inputs for both members of every mention pair, in get_pair_data order.
(
    text_1, text_2,
    prev_1, prev_2,
    next_1, next_2,
    syntactic_1, syntactic_2,
    is_singleton_1, is_singleton_2,
) = get_pair_data(pairs.m1_id, pairs.m2_id)

# Pair-level relation features.
relation = get_relation_data(pairs)

## Budi

### Words + Context + Syntactic

In [215]:
# Saved coreference classifier evaluated in this section.
words_syntactic_model_1 = load_model(
    'models/coreference_classifiers/words_context_syntactic_budi_5.model'
)

In [216]:
# Inputs are passed as one list: both texts, both contexts, both syntactic
# feature arrays, then the pair-level relation features.
syntactic_1_pred = words_syntactic_model_1.predict(
    [
        text_1, text_2,
        prev_1, prev_2,
        next_1, next_2,
        syntactic_1, syntactic_2,
        relation,
    ],
    verbose=1,
)



In [217]:
# Model scores grouped for clustering: once as-is, once with the predicted
# singleton ids passed along.
predz = get_anaphora_scores_by_antecedent(
    pairs.m1_id, pairs.m2_id, syntactic_1_pred
)
predz2 = get_anaphora_scores_by_antecedent(
    pairs.m1_id, pairs.m2_id, syntactic_1_pred, singletons
)
# Same grouping applied to the one-hot labels.
labz = get_anaphora_scores_by_antecedent(
    pairs.m1_id, pairs.m2_id, label
)

In [230]:
# Cluster both score sets with the same threshold so the resulting chains are
# directly comparable.
pred_chains = BestFirstClusterer().get_chains(
    predz,
    threshold=0.005,
)
pred_chains2 = BestFirstClusterer().get_chains(
    predz2,
    threshold=0.005,
)

In [231]:
# MUC scores for both chain sets, a blank line, then the B3 scores; one item
# per output line.
print(
    MUCScorer().get_scores(pred_chains, label_chains),
    MUCScorer().get_scores(pred_chains2, label_chains),
    '',
    B3Scorer().get_scores(pred_chains, label_chains),
    B3Scorer().get_scores(pred_chains2, label_chains),
    sep='\n',
)

(0.21241830065359477, 0.8441558441558441, 0.3394255874673629)
(0.5978260869565217, 0.7142857142857143, 0.6508875739644971)

(0.0807532311295752, 0.8297619047619047, 0.14718251731674847)
(0.3842241349185793, 0.6573489010989012, 0.4849766730331626)


In [232]:
def get_markable_text(idx):
    """Return the words of the markable whose id is ``idx``.

    The first row of ``markables`` with a matching ``id`` is used, and each
    entry of its ``text`` value is translated through ``word_by_idx``.
    """
    matching_rows = markables[markables['id'] == idx]
    word_indexes = matching_rows.text.values[0]
    return [word_by_idx[word_index] for word_index in word_indexes]

In [233]:
# Words of every predicted chain (singleton-aware run) that has more than one markable.
[
    [get_markable_text(markable_id) for markable_id in chain]
    for chain in pred_chains2
    if len(chain) > 1
]

[[['bank', 'indonesia'],
  ['bi'],
  ['hartadi', 'a', 'sarwono'],
  ['ia'],
  ['hartadi'],
  ['bi'],
  ['hartadi'],
  ['nya'],
  ['nya'],
  ['hartadi'],
  ['nya'],
  ['ia'],
  ['nya'],
  ['ansari'],
  ['ansari']],
 [['menteri', 'keuangan', 'sri', 'mulyani'],
  ['sri', 'mulyani'],
  ['mulyani'],
  ['mulyani']],
 [['aali'], ['nya'], ['direktur', 'aali'], ['santosa'], ['nya']],
 [['deputi', 'senior'],
  ['bank', 'indonesia'],
  ['bi'],
  ['miranda', 's', 'goeltom'],
  ['miranda'],
  ['miranda'],
  ['nya'],
  ['nya']],
 [['deputi', 'gubernur', 'senior', 'bi'],
  ['miranda', 's', 'goeltom'],
  ['miranda'],
  ['nya']],
 [['bank', 'mandiri'],
  ['nya'],
  ['nya'],
  ['direktur', 'teknologi', 'dan', 'operasional', 'bank', 'mandiri'],
  ['sasmita'],
  ['dia'],
  ['bank', 'mandiri'],
  ['ia'],
  ['nya'],
  ['bank', 'mandiri'],
  ['nya']],
 [['menteri', 'keuangan', 'sri', 'mulyani', 'indrawati'],
  ['menko', 'perekonomian', 'boediono'],
  ['nya'],
  ['boediono'],
  ['nya'],
  ['nya'],
  ['nya']],

In [235]:
# Words of every baseline chain that has more than one markable.
[
    [get_markable_text(markable_id) for markable_id in chain]
    for chain in baseline_chains
    if len(chain) > 1
]

[[['deputi', 'gubernur'], ['bank', 'indonesia']],
 [['hartadi', 'a', 'sarwono'],
  ['ia'],
  ['hartadi'],
  ['jakarta'],
  ['hartadi'],
  ['hartadi'],
  ['kami'],
  ['sekretaris', 'perusahaan', 'astra', 'otoparts'],
  ['kartina', 'rahayu'],
  ['dia'],
  ['dia'],
  ['nya'],
  ['nya'],
  ['menkeu', 'sri', 'mulyani', 'indrawati'],
  ['dia']],
 [['ia'],
  ['dirjend',
   'industri',
   'logam',
   'mesin',
   'tekstil',
   'departemen',
   'perindustrian',
   'ansari',
   'bukhari'],
  ['ansari'],
  ['ansari']],
 [['indonesia', 'investor', 'forum'], ['indonesia', 'investor', 'forum']],
 [['menteri', 'keuangan', 'sri', 'mulyani'],
  ['sri', 'mulyani'],
  ['mulyani'],
  ['nya']],
 [['direktur', 'aali'],
  ['santosa'],
  ['nya'],
  ['nya'],
  ['deputi', 'senior'],
  ['bank', 'indonesia'],
  ['miranda', 's', 'goeltom'],
  ['miranda'],
  ['miranda'],
  ['kami'],
  ['nya'],
  ['deputi', 'gubernur', 'senior', 'bi'],
  ['miranda', 's', 'goeltom'],
  ['miranda']],
 [['pertumbuhan', 'ekonomi', 'indon

In [236]:
# Words of every label chain that has more than one markable.
[
    [get_markable_text(markable_id) for markable_id in chain]
    for chain in label_chains
    if len(chain) > 1
]

[[['hartadi', 'a', 'sarwono'],
  ['ia'],
  ['hartadi'],
  ['hartadi'],
  ['hartadi'],
  ['ia']],
 [['dirjend',
   'industri',
   'logam',
   'mesin',
   'tekstil',
   'departemen',
   'perindustrian',
   'ansari',
   'bukhari'],
  ['ansari'],
  ['ansari']],
 [['pdb'], ['pdb'], ['pdb']],
 [['menteri', 'keuangan', 'sri', 'mulyani'],
  ['sri', 'mulyani'],
  ['mulyani'],
  ['mulyani']],
 [['direktur', 'aali'], ['santosa'], ['nya']],
 [['miranda', 's', 'goeltom'], ['miranda'], ['miranda'], ['nya'], ['nya']],
 [['deputi', 'gubernur', 'senior', 'bi'],
  ['miranda', 's', 'goeltom'],
  ['miranda'],
  ['nya']],
 [['direktur', 'teknologi', 'dan', 'operasional', 'bank', 'mandiri'],
  ['sasmita'],
  ['dia'],
  ['ia'],
  ['nya']],
 [['bank', 'mandiri'],
  ['nya'],
  ['kami'],
  ['bank', 'mandiri'],
  ['bank', 'mandiri'],
  ['nya']],
 [['menko', 'perekonomian', 'boediono'], ['nya'], ['boediono'], ['nya']],
 [['menteri', 'keuangan', 'sri', 'mulyani', 'indrawati'], ['nya'], ['nya']],
 [['analis', 'pefi