In [None]:
import sys
from tqdm import tqdm
from collections import Counter

In [None]:
## Some hyperparameters

# Which language's corpus pair to load. Only the `if language == ...` cell that
# matches this value defines train1/dev1/train2/dev2.
language = "english"
# language = "slovenian"
# language = "romanian"

# Which comparison to run:
#   "within-1": train vs. dev split of corpus 1
#   "within-2": train vs. dev split of corpus 2
#   "between":  train split of corpus 1 vs. train split of corpus 2
# comparison = "within-1"
# comparison = "within-2"
comparison = "between"

In [None]:
if language == "english":
    # Corpus 1: GUM (Georgetown University Multilayer) corpus from the UD website
    corpus1_dir = "data/ud-en-gum/"
    train1 = f"{corpus1_dir}en_gum-ud-train.conllu"
    dev1 = f"{corpus1_dir}en_gum-ud-dev.conllu"

    # Corpus 2: WSJ corpus with different annotation conventions
    # (converted from Stanford dependencies?)
    corpus2_dir = "data/wsj-DIFF-CONVENTIONS/"
    train2 = f"{corpus2_dir}train.conllu"
    dev2 = f"{corpus2_dir}dev.conllu"
    test2 = f"{corpus2_dir}test.conllu"

In [None]:
if language == "slovenian":
    # Corpus 1: SSJ
    corpus1_dir = "data/UD_Slovenian-SSJ/"
    train1 = f"{corpus1_dir}sl_ssj-ud-train.conllu"
    dev1 = f"{corpus1_dir}sl_ssj-ud-dev.conllu"

    # Corpus 2: SST (note that dev2 points at the *test* split file)
    corpus2_dir = "data/UD_Slovenian-SST/"
    train2 = f"{corpus2_dir}sl_sst-ud-train.conllu"
    dev2 = f"{corpus2_dir}sl_sst-ud-test.conllu"

In [None]:
if language == "romanian":
    # Corpus 1: RRT
    corpus1_dir = "data/UD_Romanian-RRT/"
    train1 = f"{corpus1_dir}ro_rrt-ud-train.conllu"
    dev1 = f"{corpus1_dir}ro_rrt-ud-dev.conllu"

    # Corpus 2: Nonstandard
    corpus2_dir = "data/UD_Romanian-Nonstandard/"
    train2 = f"{corpus2_dir}ro_nonstandard-ud-train.conllu"
    dev2 = f"{corpus2_dir}ro_nonstandard-ud-dev.conllu"

In [None]:
# Module-level accumulators of the distinct relation labels seen so far.
# process_ud_data appends to gum_relations and process_wsj_data appends to
# wsj_relations; re-running this cell resets both lists.
gum_relations = []
wsj_relations = []

In [None]:
def process_ud_data(ud_file):
    """
    Processes UD data into dictionaries of word pairs, relations, and sentences.

    Args:
        ud_file: path to a CoNLL-U file. Sentences are separated by blank lines and
            are expected to carry "# sent_id = ..." and "# text = ..." comment
            lines. If one of them is missing, the sentence's 0-based position
            in the file (as an int) / its space-joined word forms are used instead.

    Returns:
        A tuple (pair_to_relations, pair_relation_to_sentences).

        pair_to_relations dict:
            e.g. {("Need", "'ll"): ["aux"], ...}
            Maps (head word, dependent word) to the relations seen for that pair,
            one list entry per occurrence. The head of a root token is "#ROOT#".

        pair_relation_to_sentences dict:
            e.g. {(("funnel", "Large"), "amod"): [[sentence_id, "Large funnel or strainer to hold filter"]], ...}
            Maps ((head word, dependent word), relation) to one
            [sentence id, sentence text] entry per occurrence.

    Side effects:
        Appends every not-yet-seen relation label to the module-level
        gum_relations list and prints progress/summary lines.
    """
    print(f"\nProcessing '{ud_file}' data file...")
    # Read explicitly as UTF-8 instead of relying on the platform default encoding.
    with open(ud_file, encoding="utf-8") as infile:
        ud_lines = infile.readlines()

    # get list of sentences as (sent_id, text, token rows) triples
    sentences = []
    sentence_id, sentence_text, token_rows = None, None, []
    # The extra "" acts as a final blank line so that a last sentence that is not
    # followed by a blank line is still flushed.
    for line in ud_lines + [""]:
        if len(line.strip()) == 0:
            if token_rows:
                sentences.append((sentence_id, sentence_text, token_rows))
            sentence_id, sentence_text, token_rows = None, None, []
        elif line[0] == "#":
            if "# sent_id = " in line:
                sentence_id = line.split(" sent_id = ", 1)[1].strip()
            elif "# text =" in line:
                sentence_text = line.split(" text =", 1)[1].strip()
            # every other comment line is ignored
        else:
            token_rows.append(line.split("\t"))

    pair_to_relations = {}
    pair_relation_to_sentences = {}
    for position, (sentence_id, sentence_text, token_rows) in enumerate(tqdm(sentences)):
        # Only rows whose ID is a plain integer are syntactic words. Range IDs
        # ("1-2") and decimal IDs ("1.1") have no integer HEAD and would shift a
        # purely positional head lookup, so they are skipped.
        words = [row for row in token_rows if row[0].isdigit()]
        # Resolve heads by token ID rather than by list position.
        form_by_id = {int(row[0]): row[1] for row in words}
        if sentence_id is None:
            sentence_id = position
        if sentence_text is None:
            sentence_text = " ".join(row[1] for row in words)
        for row in words:
            word = row[1]
            head_idx = int(row[6])
            head_word = form_by_id[head_idx] if head_idx != 0 else "#ROOT#"
            relation = row[7]
            if relation not in gum_relations:
                gum_relations.append(relation)
            word_pair = (head_word, word)
            pair_to_relations.setdefault(word_pair, []).append(relation)
            pair_relation = (word_pair, relation)
            pair_relation_to_sentences.setdefault(pair_relation, []).append(
                [sentence_id, sentence_text]
            )

    print(f"\t{len(pair_to_relations):,} head-dependent:relation pairs")
    print(f"\t{len(pair_relation_to_sentences):,} head-dependent-relation:sentence triples")
    return pair_to_relations, pair_relation_to_sentences

In [None]:
def process_wsj_data(ud_file):
    """
    Processes metadata-free CoNLL-U data into dictionaries of word pairs,
    relations, and sentences.

    Unlike process_ud_data, sentence ids and texts are never read from comment
    lines: the id is the sentence's 0-based position among the non-empty
    sentences of the file and the text is its word forms joined by spaces.
    Comment lines ("#...") are ignored.

    Args:
        ud_file: path to a CoNLL-U style file with blank-line separated sentences.

    Returns:
        A tuple (pair_to_relations, pair_relation_to_sentences).

        pair_to_relations dict:
            e.g. {("Need", "'ll"): ["aux"], ...}
            Maps (head word, dependent word) to the relations seen for that pair,
            one list entry per occurrence. The head of a root token is "#ROOT#".

        pair_relation_to_sentences dict:
            e.g. {(("funnel", "Large"), "amod"): [[12, "Large funnel or strainer to hold filter"]], ...}
            Maps ((head word, dependent word), relation) to one
            [sentence index, sentence text] entry per occurrence.

    Side effects:
        Appends every not-yet-seen relation label to the module-level
        wsj_relations list and prints progress/summary lines.
    """
    print(f"\nProcessing '{ud_file}' data file...")
    # Read explicitly as UTF-8 instead of relying on the platform default encoding.
    with open(ud_file, encoding="utf-8") as infile:
        ud_lines = infile.readlines()

    # get list of sentences, each a list of token rows split on tabs
    sentences = []
    token_rows = []
    # The extra "" acts as a final blank line so that a last sentence that is not
    # followed by a blank line is still flushed.
    for line in ud_lines + [""]:
        if len(line.strip()) == 0:
            if token_rows:
                sentences.append(token_rows)
            token_rows = []
        elif line[0] == "#":
            # Comment lines must not end up among the token rows, otherwise
            # every row-based lookup below would be thrown off.
            continue
        else:
            token_rows.append(line.split("\t"))

    pair_to_relations = {}
    pair_relation_to_sentences = {}
    # enumerate() gives each sentence its own index; looking the sentence up with
    # list.index() is quadratic and returns the first match for duplicates.
    for sentence_id, token_rows in enumerate(tqdm(sentences)):
        # Only rows whose ID is a plain integer are syntactic words; range/decimal
        # IDs have no integer HEAD and are skipped.
        words = [row for row in token_rows if row[0].isdigit()]
        # Resolve heads by token ID rather than by list position.
        form_by_id = {int(row[0]): row[1] for row in words}
        sentence_text = " ".join(row[1] for row in words)
        for row in words:
            word = row[1]
            head_idx = int(row[6])
            head_word = form_by_id[head_idx] if head_idx != 0 else "#ROOT#"
            relation = row[7]
            if relation not in wsj_relations:
                wsj_relations.append(relation)
            word_pair = (head_word, word)
            pair_to_relations.setdefault(word_pair, []).append(relation)
            pair_relation = (word_pair, relation)
            pair_relation_to_sentences.setdefault(pair_relation, []).append(
                [sentence_id, sentence_text]
            )

    print(f"\t{len(pair_to_relations):,} head-dependent:relation pairs")
    print(f"\t{len(pair_relation_to_sentences):,} head-dependent-relation:sentence triples")
    return pair_to_relations, pair_relation_to_sentences

In [None]:
# Corpus 1 is always a regular UD file with "# sent_id" / "# text" lines.
# Corpus 2 for English is the WSJ file, which has no such metadata lines, so it
# needs process_wsj_data (ids/texts derived from the token rows); corpus 2 of the
# other languages is a regular UD file and needs process_ud_data.
corpus2_parser = process_wsj_data if language == "english" else process_ud_data

if comparison == "within-1":
    # train vs. dev split of corpus 1
    train_file = train1
    dev_file = dev1
    train_output = f"{language}_train_mismatches_within_corpus1.tsv"
    dev_output = f"{language}_dev_mismatches_within_corpus1.tsv"
    train_list, train_sentences = process_ud_data(train_file)
    dev_list, dev_sentences = process_ud_data(dev_file)

elif comparison == "within-2":
    # train vs. dev split of corpus 2
    train_file = train2
    dev_file = dev2
    train_output = f"{language}_train_mismatches_within_corpus2.tsv"
    # was "..._within_corpus_2.tsv", inconsistent with the train output name above
    dev_output = f"{language}_dev_mismatches_within_corpus2.tsv"
    train_list, train_sentences = corpus2_parser(train_file)
    dev_list, dev_sentences = corpus2_parser(dev_file)

elif comparison == "between":
    # train split of corpus 1 vs. train split of corpus 2; the "dev" names are
    # kept so that the cells below work unchanged for every comparison mode
    train_file = train1
    dev_file = train2
    train_output = f"{language}_train_mismatches_between_corpora.tsv"
    dev_output = f"{language}_dev_mismatches_between_corpora.tsv"
    train_list, train_sentences = process_ud_data(train_file)
    dev_list, dev_sentences = corpus2_parser(dev_file)

else:
    # Fail here rather than with a NameError in a later cell.
    raise ValueError(
        f"Unknown comparison {comparison!r}; expected 'within-1', 'within-2' or 'between'"
    )

In [None]:
# for triple in train_sentences:
# #     print(triple, train_sentences[triple])
#     for sentence in train_sentences[triple]:
#         print(triple, sentence)

In [None]:
# for pair in dev_list:
#     print(pair, dev_list[pair])

In [None]:
# for triple in dev_sentences:
# #     print(triple, dev_sentences[triple])
#     for sentence in dev_sentences[triple]:
#         print(triple, sentence)

In [None]:
# Compare the two pair -> relations dictionaries.
# For every (head, dependent) pair that occurs in BOTH dev and train:
#   dev_mismatches[pair]   = dev relations (one per occurrence) never seen for that pair in train
#   train_mismatches[pair] = train relations (one per occurrence) never seen for that pair in dev
# Pairs that occur in only one of the two dictionaries are not compared.
dev_mismatches = {}
train_mismatches = {}
for dev_pair, dev_relations in dev_list.items():
    if dev_pair not in train_list:
        continue
    train_relations = train_list[dev_pair]
    # Build each lookup set once per pair instead of once per list element.
    dev_relation_set = set(dev_relations)
    train_relation_set = set(train_relations)
    # TODO: decide which ones we actually care about...
    not_in_train = [x for x in dev_relations if x not in train_relation_set]
    not_in_dev = [x for x in train_relations if x not in dev_relation_set]
    if not_in_train:
        dev_mismatches[dev_pair] = not_in_train
    if not_in_dev:
        train_mismatches[dev_pair] = not_in_dev

In [None]:
# Blank line followed by one summary line per direction of the comparison.
mismatch_summary = (
    "",
    f"{len(dev_mismatches)} pairs with a relation in dev but not in train",
    f"{len(train_mismatches)} pairs with a relation in train but not in dev",
)
print("\n".join(mismatch_summary))

In [None]:
# dev_mismatches

In [None]:
# train_mismatches

In [None]:
def generate_human_readable_output(filename, mismatches, sentences, this_data, other_data, verbose=True):
    """
    Generates human-readable output file for evaluation.

    One tab-separated row is written per (word pair, mismatching relation), in
    sorted relation order within each pair so that repeated runs produce the
    same file.

    Args:
        filename: the output filename (written as UTF-8, overwritten if present)

        mismatches: the dictionary of word pairs with a relation in one file but not the other,
            {(head word, dependent word): [relation, ...]}

        sentences: the dictionary of word pairs, their relations, and the sentences they're found in,
            {((head word, dependent word), relation): [[sentence id, sentence text], ...]}

        this_data: {(head word, dependent word): [relation, ...]} for the data the
            mismatches come from

        other_data: the same mapping for the data being compared against; must
            contain every pair in mismatches

        verbose: if True (default), also print every written row

    Returns:
        A dictionary keyed by (head word, dependent word, relation) whose values
        hold the most common relation of that word pair, its count and its
        proportion, in this data and in the other data.
    """
    columns = [
        "ID",
        "SENTENCE",
        "HEAD WORD",
        "DEPENDENT WORD",
        "RELATION",
        "TOP RELATION IN THIS DATA",
        "COUNT OF TOP RELATION IN THIS DATA",
        "PROPORTION OF TOP RELATION IN THIS DATA",
        "TOP RELATION IN OTHER DATA",
        "COUNT OF TOP RELATION IN OTHER DATA",
        "PROPORTION OF TOP RELATION IN OTHER DATA",
    ]
    conversion_dict = {}
    # Sentence texts may contain non-ASCII characters; write UTF-8 explicitly
    # instead of relying on the platform default encoding.
    with open(filename, "w", encoding="utf-8") as output:
        output.write("\t".join(columns) + "\n")
        for pair in mismatches:
            head_word = pair[0]
            dependent_word = pair[1]
            # relations for this pair in this partition / corpus
            these_relations = Counter(this_data[pair])
            # relations for this pair in other partition / corpus
            other_relations = Counter(other_data[pair])
            # sorted() makes the row order deterministic
            relations = sorted(set(mismatches[pair]))
            if not relations:
                continue
            # The statistics depend only on the pair, so compute them once per
            # pair rather than once per relation.
            most_common_label, most_common_count = these_relations.most_common(1)[0]
            proportion_these = most_common_count / sum(these_relations.values())
            most_common_label_other, most_common_count_other = other_relations.most_common(1)[0]
            proportion_other = most_common_count_other / sum(other_relations.values())
            for relation in relations:
                triple = (pair, relation)
                sentence_ids = [str(sentence[0]) for sentence in sentences[triple]]
                sentence_texts = [sentence[1] for sentence in sentences[triple]]
                line = ("('"+("', '").join(sentence_ids)+"')" + "\t" +
                        "('"+("', '").join(sentence_texts)+"')" + "\t" +
                        head_word + "\t" +
                        dependent_word + "\t" +
                        relation + "\t" +
                        most_common_label + "\t" +
                        str(most_common_count) + "\t" +
                        str(proportion_these) + "\t" +
                        most_common_label_other + "\t" +
                        str(most_common_count_other) + "\t" +
                        str(proportion_other) + "\n"
                        )
                key = (head_word, dependent_word, relation)
                conversion_dict[key] = {"most_common_label": most_common_label,
                                        "most_common_count": most_common_count,
                                        "proportion_these": proportion_these,
                                        "most_common_label_other": most_common_label_other,
                                        "most_common_count_other": most_common_count_other,
                                        "proportion_other": proportion_other}
                if verbose:
                    print("LINE:\t"+line)
                output.write(line)
    return conversion_dict

In [None]:
# Write the report for relations found in dev but not in train (dev is "this"
# data, train is the "other" data) and keep the per-triple statistics.
dev_conversion = generate_human_readable_output(dev_output, dev_mismatches, dev_sentences, dev_list, train_list)

In [None]:
# Write the report for relations found in train but not in dev (train is "this"
# data, dev is the "other" data) and keep the per-triple statistics.
train_conversion = generate_human_readable_output(train_output, train_mismatches, train_sentences, train_list, dev_list)

In [None]:
# Spot check of one English example triple. .get() is used so that the cell
# shows nothing instead of raising KeyError when the triple is absent (e.g. for
# another language or comparison mode).
dev_conversion.get(('role', 'played', 'acl'))

In [None]:
# Preview the first entries only; the full dictionary can be very large.
dict(list(train_conversion.items())[:10])

In [None]:
def apply_conversions(input_file, output_file, conversion_dictionary, threshold=0.0, write_output=False):
    """
    Relabels dependency relations of a CoNLL-U file according to a conversion dictionary.

    For every token, the (head word, word, relation) triple is looked up in
    conversion_dictionary; on a hit the token's relation is replaced by the
    entry's "most_common_label_other", provided the entry's "proportion_other"
    is at least threshold.

    Args:
        input_file: path of the CoNLL-U file to read (UTF-8).

        output_file: path the converted file is written to; only used when
            write_output is True.

        conversion_dictionary: {(head word, word, relation): {...}} as returned by
            generate_human_readable_output.

        threshold: minimum "proportion_other" required to relabel a token. With
            the default of 0.0 every matching token is relabeled. Entries
            without a "proportion_other" key are treated as 1.0.

        write_output: if True, write the converted sentences to output_file,
            overwriting it. The default False keeps the function free of file
            output, as it was while the writing code was disabled.

    Returns:
        None.
    """
    # Read explicitly as UTF-8 instead of relying on the platform default encoding.
    with open(input_file, encoding="utf-8") as infile:
        lines = infile.readlines()

    # Split the file into sentences of (comment lines, token rows).
    sentences = []
    metadata, rows = [], []
    # The extra "" acts as a final blank line so that a last sentence that is not
    # followed by a blank line is still flushed.
    for line in lines + [""]:
        line = line.rstrip("\n")
        if len(line.strip()) == 0:
            if metadata or rows:
                sentences.append((metadata, rows))
            metadata, rows = [], []
        elif line[0] == "#":
            metadata.append(line)
        else:
            rows.append(line.split("\t"))

    for metadata, rows in sentences:
        # Resolve heads by token ID rather than by list position, so the lookup
        # does not depend on a "# text" line being present or on the absence of
        # range ("1-2") / decimal ("1.1") ID rows.
        form_by_id = {int(row[0]): row[1] for row in rows if row[0].isdigit()}
        for row in rows:
            if not row[0].isdigit():
                # range/decimal ID rows carry no integer HEAD; leave them untouched
                continue
            head_idx = int(row[6])
            head_word = form_by_id[head_idx] if head_idx != 0 else "#ROOT#"
            entry = conversion_dictionary.get((head_word, row[1], row[7]))
            if not entry:
                continue
            if entry.get("proportion_other", 1.0) >= threshold:
                row[7] = entry["most_common_label_other"]

    if write_output:
        # "w" rather than "a": re-running must not append a second copy.
        with open(output_file, "w", encoding="utf-8") as outfile:
            for metadata, rows in sentences:
                for comment in metadata:
                    outfile.write(comment + "\n")
                for row in rows:
                    outfile.write("\t".join(row) + "\n")
                outfile.write("\n")
            

In [None]:
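# A separate WSJ conversion function is needed because the English WSJ data has no
# metadata ("# text = ...") lines: its token rows start at list index 0 instead of
# index 1, so the position-based head lookup needs a different offset.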

In [None]:
def apply_wsj_conversions(input_file, conversion_dictionary, threshold, write_output=False):
    """
    Relabels dependency relations of a CoNLL-U style file according to a conversion dictionary.

    For every token, the (head word, word, relation) triple is looked up in
    conversion_dictionary; on a hit the token's relation is replaced by the
    entry's "most_common_label_other", provided the entry's "proportion_other"
    is at least threshold.

    Args:
        input_file: path of the file to read (UTF-8). Comment lines are optional.

        conversion_dictionary: {(head word, word, relation): {...}} as returned by
            generate_human_readable_output.

        threshold: minimum "proportion_other" required to relabel a token.

        write_output: if True, write the converted sentences to
            input_file + "_converted.conllu", overwriting it. The default False
            keeps the function free of file output, as it was while the writing
            code was disabled.

    Returns:
        None.
    """
    output_file = input_file+"_converted.conllu"
    # Read explicitly as UTF-8 instead of relying on the platform default encoding.
    with open(input_file, encoding="utf-8") as infile:
        lines = infile.readlines()

    # Split the file into sentences of (comment lines, token rows).
    sentences = []
    metadata, rows = [], []
    # The extra "" acts as a final blank line so that a last sentence that is not
    # followed by a blank line is still flushed.
    for line in lines + [""]:
        line = line.rstrip("\n")
        if len(line.strip()) == 0:
            if metadata or rows:
                sentences.append((metadata, rows))
            metadata, rows = [], []
        elif line[0] == "#":
            metadata.append(line)
        else:
            rows.append(line.split("\t"))

    for metadata, rows in sentences:
        # Resolve heads by token ID rather than by list position, so the lookup
        # works with or without comment lines and is not shifted by range
        # ("1-2") / decimal ("1.1") ID rows.
        form_by_id = {int(row[0]): row[1] for row in rows if row[0].isdigit()}
        for row in rows:
            if not row[0].isdigit():
                # range/decimal ID rows carry no integer HEAD; leave them untouched
                continue
            head_idx = int(row[6])
            head_word = form_by_id[head_idx] if head_idx != 0 else "#ROOT#"
            triple = (head_word, row[1], row[7])
            if not conversion_dictionary.get(triple):
                continue
            new_relation = conversion_dictionary[triple]["most_common_label_other"]
            other_proportion = conversion_dictionary[triple]["proportion_other"]
            if other_proportion >= threshold:
                row[7] = new_relation

    if write_output:
        # "w" rather than "a": re-running must not append a second copy.
        with open(output_file, "w", encoding="utf-8") as outfile:
            for metadata, rows in sentences:
                for comment in metadata:
                    outfile.write(comment + "\n")
                for row in rows:
                    outfile.write("\t".join(row) + "\n")
                outfile.write("\n")
            

In [None]:
# Names of the converted files: the input path with a fixed suffix appended.
converted_suffix = "_converted.conllu"
converted_train_file = f"{train_file}{converted_suffix}"
converted_dev_file = f"{dev_file}{converted_suffix}"
print(converted_train_file)

In [None]:
# Relabel the tokens of train_file using the train-side conversion dictionary.
# NOTE(review): as called here, apply_conversions does not write
# converted_train_file to disk — confirm whether that is intended.
apply_conversions(train_file, converted_train_file, train_conversion)

In [None]:
# Relabel the tokens of dev_file using the dev-side conversion dictionary, but
# only where the other data's top relation has a proportion of at least 0.5.
# NOTE(review): as called here, apply_wsj_conversions does not write a converted
# file to disk — confirm whether that is intended.
apply_wsj_conversions(dev_file, dev_conversion, 0.5)