In [11]:
import json
import pandas as pd
import numpy as np
from utils.data_helper import get_markable_dataframe, get_embedding_variables, get_sentence_variables, get_document_id_variables, get_phrases_and_nodes
from model_builders.coreference_classifier import CoreferenceClassifierModelBuilder
from functools import reduce
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model
from utils.clusterers import BestFirstClusterer, get_anaphora_scores_by_antecedent, ClosestFirstClusterer
from utils.scorers import MUCScorer, B3Scorer, CEAFeScorer, AverageScorer
from utils.data_structures import UFDS
from xml.etree import ElementTree
from string import punctuation
from IPython.display import HTML, display
from tabulate import tabulate
from html import escape

ModuleNotFoundError: No module named 'tabulate'

In [2]:
pairs = pd.read_csv("data/testing/mention_pairs.csv")

label = np.vstack(to_categorical(pairs.is_coreference, num_classes=2))
label_chains = ClosestFirstClusterer().get_chains(get_anaphora_scores_by_antecedent(pairs.m1_id, pairs.m2_id, label))
# label_chains = sorted(list(filter(lambda x: len(x) > 1, label_chains)))

In [3]:
sentence_id_by_markable_id, markable_ids_by_sentence_id = get_sentence_variables('data/full.xml')
document_id_by_sentence_id, document_id_by_markable_id, sentence_ids_by_document_id, markable_ids_by_document_id = get_document_id_variables('data/document_id.csv', markable_ids_by_sentence_id)

data = ElementTree.parse('data/full.xml')

root = data.getroot()
parent_map = {c: p for p in root.iter() for c in p}

phrases, nodes, phrase_id_by_node_id = get_phrases_and_nodes(UFDS(), root)

phrases_by_sentence_id = {}

aneh = []
for phrase in phrases:
    sentence = parent_map[phrase]
    sentence_id = int(sentence.attrib['id'])
    
    if sentence_id not in phrases_by_sentence_id:
        phrases_by_sentence_id[sentence_id] = []
    
    if 'coref' in phrase.attrib:
        if document_id_by_markable_id[int(phrase.attrib['coref'])] != document_id_by_sentence_id[sentence_id]:
            aneh.append(document_id_by_markable_id[int(phrase.attrib['coref'])])
            aneh.append(document_id_by_sentence_id[sentence_id])
            
    phrases_by_sentence_id[sentence_id].append(phrase)

In [4]:
baseline_result_file_path = 'baseline/suherik_and_purwarianti/test_result.txt'

baseline_ufds = UFDS()

for m1, m2 in zip(pairs.m1_id, pairs.m2_id):
    baseline_ufds.init_id(m1, m2)
    
for line in open(baseline_result_file_path, 'r').readlines():
    line = line.split(', ')
    m1_id, m2_id = int(line[0]), int(line[1])
    
    if document_id_by_markable_id[m1_id] == document_id_by_markable_id[m2_id]:
        baseline_ufds.join(m1_id, m2_id)

baseline_chains = baseline_ufds.get_chain_list()

print('MUC: ', MUCScorer().get_scores(baseline_chains, label_chains))
print('B3: ', B3Scorer().get_scores(baseline_chains, label_chains))
print('Average: ', AverageScorer([MUCScorer(), B3Scorer(), CEAFeScorer()]).get_scores(baseline_chains, label_chains))

MUC:  (0.6395348837209303, 0.7051282051282052, 0.6707317073170733)
B3:  (0.5041087231352718, 0.6379818594104308, 0.5631991462778547)
Average:  (0.4185637051277797, 0.4185637051277797, 0.41856370512777963)


In [5]:
data = ElementTree.parse('data/testing/data.xml')

root = data.getroot()

test_document_ids = set()
for sentence in root:
    test_document_ids.add(document_id_by_sentence_id[int(sentence.attrib['id'])])

test_document_ids = list(test_document_ids)

In [6]:
def get_text_without_pos_tag(text):
    words = text.split()

    for i in range(len(words)):
        words[i] = ''.join(words[i].split('\\')[:-1])
    
    return ' '.join(words)
    
def get_sentence(sentence_id):
    phrases = phrases_by_sentence_id[sentence_id]
    
    splitted_sentence = []
    for phrase in phrases:
        text = get_text_without_pos_tag(phrase.text)
        
        if 'id' in phrase.attrib:
            splitted_sentence.append(f'({text})[{phrase.attrib["id"]}]')
        else:
            splitted_sentence.append(text)
            
    sentence = ' '.join(splitted_sentence)
    return sentence

def get_document_text(document_id):
    sentence_ids = sentence_ids_by_document_id[document_id]
    sentences = [get_sentence(sentence_id) for sentence_id in sentence_ids]
    
    document_text = ''
    for sentence in sentences:
        document_text += sentence
        document_text += ' ' if sentence[-1] in punctuation else '\n'
        
    return document_text

# Get Baseline Chains

In [7]:
baseline_result_file_path = 'baseline/suherik_and_purwarianti/test_result.txt'

baseline_ufds = UFDS()

for m1, m2 in zip(pairs.m1_id, pairs.m2_id):
    baseline_ufds.init_id(m1, m2)
    
for line in open(baseline_result_file_path, 'r').readlines():
    line = line.split(', ')
    baseline_ufds.join(int(line[0]), int(line[1]))

baseline_chains = list(filter(lambda x: len(x) > 1, baseline_ufds.get_chain_list()))

# Get Predicted Chains

In [8]:
predicted_result_wo_sc_file_path = 'result/without_singleton_classifier.json'
predicted_result_w_sc_file_path = 'result/with_singleton_classifier.json'
predicted_result_w_label_sc_file_path = 'result/with_label_singleton_classifier.json'

wo_sc_chains = list(map(lambda chain: [markable[0] for markable in chain], json.load(open(predicted_result_wo_sc_file_path))))
w_sc_chains = list(map(lambda chain: [markable[0] for markable in chain], json.load(open(predicted_result_w_sc_file_path))))
w_label_sc_chains = list(map(lambda chain: [markable[0] for markable in chain], json.load(open(predicted_result_w_label_sc_file_path))))

# Analysis for Each Document

In [9]:
def get_chains_by_document_id(chains, document_id):
    filtered_chains = []
    
    for chain in chains:
        filtered_chain = [markable_id for markable_id in chain if document_id_by_markable_id[markable_id] == document_id]
        if len(filtered_chain) > 1:
            filtered_chains.append(filtered_chain)
    
    return filtered_chains

In [10]:
chains_list = [label_chains, baseline_chains, wo_sc_chains, w_sc_chains, w_label_sc_chains]

def get_printable_chains(chains):
    printable = ''
    
    for chain in chains:
        printable += str(chain)
        printable += '<br />'
    
    return printable
    
for document_id in test_document_ids:
    print(get_document_text(document_id))
    
    display(HTML(tabulate(
        [[get_printable_chains(get_chains_by_document_id(chains, document_id)) for chains in chains_list]], tablefmt='html')))
    
    print()

(PT Astra Otoparts Tbk)[1916] menjual kepemilikan saham (nya)[1917] di (PT Exedy Indonesia)[1918] (sebanyak 7.072 saham)[1919] dengan (harga)[1920] (Rp 1.835.000)[1921] per (saham)[1922] atau total (nya)[1923] (senilai Rp 12,977 miliar)[1924] kepada (Exedy Corporation)[1925] yang berkedudukan di (Jepang.)[1926] (Kami)[1927] telah menandatangani (perjanjian pengikatan)[1928] untuk penjualan (seluruh saham)[1929] milik (PT Astra Otoparts)[1930] di (PT Exedy Indonesia)[1931] kepada (Exedy Corporation,)[1932] kata (Sekretaris Perusahaan Astra Otoparts,)[1933] (Kartina Rahayu)[1934] di (Jakarta,)[1935] (Senin.)[1936] Menurut (dia,)[1937] (penjualan saham tersebut)[1938] akan berlaku efektif apabila syarat-syarat sebagaimana termuat dalam (perjanjian pengikatan jual beli saham)[1939] telah terpenuhi. (Dia)[1940] mengatakan (transaksi tersebut)[1941] tidak mengandung (unsur benturan kepentingan)[1942] dilihat dari sisi (direksi,)[1943] (komisaris)[1944] dan (pemegang saham utama.)[1945] Lagi 

NameError: name 'HTML' is not defined