In [6]:
import random

import pandas as pd
import numpy as np
import networkx as nx

from tqdm.notebook import tqdm
from argparse import ArgumentParser
from collections import namedtuple
from ast import literal_eval

import main  # to use functions accessing data

# Columns of CONLL file
# Field names for one parsed sentence; they become the attribute names (and
# positional order) of the `Observation` namedtuple below.
# Fields relied on elsewhere in this notebook:
#   - `index`: token positions, converted with int() and treated as 1-indexed;
#     it is field 0, so len(observation[0]) is used as the sentence length.
#   - `sentence`: the tokens, looked up as sentence[dependent_index - 1].
#   - `head_indices`: head position per token, converted with int(); 0 is
#     treated as the root.
#   - `governance_relations`: one relation label per token.
# NOTE(review): the remaining fields presumably follow the CoNLL-X column
# layout -- confirm against main.load_conll_dataset.
CONLL_COLS = ['index',
              'sentence',
              'lemma_sentence',
              'upos_sentence',
              'xpos_sentence',
              'morph',
              'head_indices',
              'governance_relations',
              'secondary_relations',
              'extra_info']
# Record type for a single sentence; passed to main.load_conll_dataset.
ObservationClass = namedtuple("Observation", CONLL_COLS)

In [None]:
# Gold dependency parses for the PTB-WSJ dev split. Later cells iterate over
# OBSERVATIONS and index into it, and read each item's fields by the names in
# CONLL_COLS.
# NOTE(review): the paths below are relative to the notebook's working
# directory -- confirm the data layout before a Restart & Run All.
OBSERVATIONS = main.load_conll_dataset("../ptb3-wsj-data/ptb3-wsj-dev.conllx",
                                       ObservationClass)
# Per-sentence results of one model run. Later cells read the "gold_edges" and
# "projective.edges.sum" columns, which hold edge lists stored as strings.
# NOTE(review): rows are assumed to be aligned one-to-one with OBSERVATIONS -- verify.
EDGES_DF = pd.read_csv("../results-azure/bert-large-cased_pad60_2020-04-09-13-57/scores_bert-large-cased_pad60_2020-04-09-13-57.csv")

In [924]:
# Tokens that are treated as punctuation when they appear as the dependent of an edge.
EXCLUDED_PUNCTUATION = [
    "", "'", "''", ",", ".", ";", "!", "?", ":", "``", "-LRB-", "-RRB-",
]


def is_edge_to_ignore(edge, observation):
    """Tell whether a gold edge should be left out of scoring.

    `edge` is a sequence whose first two items are (head, dependent), both
    1-indexed token positions with head 0 meaning the root. An edge is ignored
    when its head is the root or when the dependent token is listed in
    EXCLUDED_PUNCTUATION.
    """
    # Look the token up first so an out-of-range dependent index still raises,
    # even for root edges.
    dependent_token = observation.sentence[edge[1] - 1]
    if edge[0] == 0:
        return True
    return dependent_token in EXCLUDED_PUNCTUATION

def score_observation(predicted_edges, observation, convert_from_0index=False):
    """Score predicted edges against the gold edges of one observation.

    The score is the fraction of gold edges (unlabeled and undirected, after
    dropping the edges rejected by `is_edge_to_ignore`) that also appear among
    the predicted edges.

    Parameters
    ----------
    predicted_edges : iterable of pairs
        Predicted edges as (node, node) pairs; direction is ignored.
    observation : Observation
        Provides `head_indices`, `index` and `governance_relations`
        (gold edges are 1-indexed, head 0 being the root).
    convert_from_0index : bool, default False
        Set to True when `predicted_edges` use 0-indexed token positions;
        they are shifted by one to match the gold edges.

    Returns
    -------
    float
        The score in [0, 1], or NaN when no gold edge is left to score.
    """
    heads = list(map(int, observation.head_indices))
    dependents = list(map(int, observation.index))

    # Undirected gold edges (1-indexed), keeping only the edges that count.
    gold_edges_set = {
        tuple(sorted((head, dependent)))
        for head, dependent, label in zip(heads, dependents,
                                          observation.governance_relations)
        if not is_edge_to_ignore((head, dependent, label), observation)
    }

    # Shift predictions to 1-indexing if needed, and drop direction.
    offset = 1 if convert_from_0index else 0
    predicted_edges_set = {tuple(sorted((edge[0] + offset, edge[1] + offset)))
                           for edge in predicted_edges}

    num_total = len(gold_edges_set)
    if num_total == 0:
        # `np.nan` rather than the `np.NaN` alias, which NumPy 2.0 removed.
        return np.nan
    num_correct = len(gold_edges_set & predicted_edges_set)
    return num_correct / float(num_total)

In [None]:
def Lk(ls):
    """Summarise a sequence of edge lengths.

    Given a length sequence ls,
    Returns
        L: the length sequence (sorted)
        k: a list where k[n] is the number of times length n occurs in L,
           for every n from 0 up to the largest length

    An empty `ls` yields ([], []).
    """
    L = sorted(ls)
    if not L:
        # No lengths at all: there is no largest length to size `k` by.
        return L, []
    # Tally once instead of calling L.count() for every possible length.
    counts = {}
    for length in L:
        counts[length] = counts.get(length, 0) + 1
    k = [counts.get(n, 0) for n in range(L[-1] + 1)]
    return L, k

def lengthmatch(observation):
    '''Collect the gold edges of an observation and their lengths.

    Returns a pair (lens, gold_edges_set): the set holds the unlabeled,
    undirected gold edges (1-indexed, smaller node first) that
    `is_edge_to_ignore` does not reject, and `lens` holds the length
    (difference between the two node indices) of each edge in that set.
    '''
    heads = [int(head) for head in observation.head_indices]
    dependents = [int(dep) for dep in observation.index]
    gold_triples = list(zip(heads, dependents, observation.governance_relations))

    # (head, dependent) -> relation label, for the edges that are kept
    kept = {}
    for triple in gold_triples:
        if is_edge_to_ignore(triple, observation):
            continue
        kept[(triple[0], triple[1])] = triple[2]

    # drop labels and direction
    undirected = set()
    for pair in kept:
        undirected.add(tuple(sorted(pair)))

    spans = [high - low for low, high in undirected]
    return spans, undirected

In [None]:
# obs = OBSERVATIONS[13]

# lens, ges = lengthmatch(obs)
# print(lens)
# print(ges)

# G=nx.Graph()
# G.add_edges_from(ges)
# # G.remove_edges_from(ges)
# print(sorted(G.edges()))
# bool((2,1) in G.edges())
# print("Is a tree?",nx.is_forest(G))
# score_observation(G.edges(),obs)

In [None]:
def generate_lengthmatch(observation, make_tree=True, iterations=100, rng=None):
    '''
    Generate a random length-matched baseline for an observation and score it.

    A random graph is built over the nodes of the gold edges so that it has
    exactly as many edges of each length as the gold edge set. With
    `make_tree` set, candidates are drawn repeatedly (rejection sampling)
    until one is a tree. Follows the algorithm described in
    https://cs.stackexchange.com/questions/116193

    Parameters
    ----------
    observation : Observation
        The sentence whose gold edges define the nodes and the edge lengths.
    make_tree : bool, default True
        Only accept a sampled graph if it is a tree.
    iterations : int, default 100
        Maximum number of sampling attempts.
    rng : random.Random, optional
        Source of randomness; defaults to the global `random` module.
        Pass a seeded `random.Random` for reproducible baselines.

    Returns
    -------
    float
        Score of the sampled graph as given by `score_observation`, or NaN
        when there are no gold edges or no acceptable graph was found within
        `iterations` attempts.

    Raises
    ------
    ValueError
        If no edge of a required length can be placed.
    '''
    rng = random if rng is None else rng
    lens, ges = lengthmatch(observation)
    print(f"len {len(observation[0])}", end=". ")
    if not ges:
        # Nothing to match: no nodes, no lengths, and the score is undefined.
        print("FAILED (no gold edges to match)")
        return np.nan

    nodes = sorted({node for edge in ges for node in edge})
    node_set = set(nodes)
    # The length statistics do not change between attempts.
    L, k = Lk(lens)
    distinct_lengths = sorted(set(L))

    for iteration in range(iterations):
        # start every attempt from an edgeless graph over the gold nodes
        T = nx.Graph()
        T.add_nodes_from(nodes)

        # for each edge length
        for l in distinct_lengths:
            E = T.edges()
            # Possible new edges of length l, stored as (low, high). Scanning
            # u -> u+l already covers every such pair, so no separate u -> u-l
            # pass is needed.
            P = sorted((u, u + l) for u in nodes
                       if u + l in node_set and (u, u + l) not in E)
            if len(P) == 0:
                raise ValueError("ERROR: no possible edges")
            # sampling k[l] edges of length l from P; random.sample requires a
            # sequence (sets are rejected since Python 3.11)
            T.add_edges_from(rng.sample(P, k[l]))

            if make_tree and not nx.is_forest(T):
                # already has a cycle: this attempt cannot become a tree
                break

        if not make_tree or nx.is_tree(T):
            if make_tree:
                print(f"success after {iteration} iterations")
            return score_observation(T.edges(), observation)

    print(f"FAILED (reached max iterations {iterations})")
    # `np.nan` rather than the `np.NaN` alias, which NumPy 2.0 removed.
    return np.nan
        
# generate_lengthmatch(OBSERVATIONS[65],make_tree=True)

In [None]:
# Only sentences within this length window are scored, and each gets at most
# MAX_ITERATIONS rejection-sampling attempts.
MIN_SENTENCE_LEN = 14
MAX_SENTENCE_LEN = 35
MAX_ITERATIONS = 15000
# Seed the sampler so that Restart & Run All reproduces the same baseline.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)

scores = []
countobs = 0
for i, obs in enumerate(tqdm(OBSERVATIONS)):
    print(str(i).ljust(4), end="")
    sentence_len = len(obs[0])
    if MIN_SENTENCE_LEN <= sentence_len <= MAX_SENTENCE_LEN:
        countobs += 1
        scores.append(generate_lengthmatch(obs, iterations=MAX_ITERATIONS))
    else:
        # covers sentences that are too short as well as too long
        print(f"len {sentence_len}",
              f"(skipped: outside {MIN_SENTENCE_LEN}-{MAX_SENTENCE_LEN})")
# Count NaN scores by value; list.count() would only match the very same NaN object.
failurecount = int(np.isnan(scores).sum())

With sentence length $14 \leq |s| \leq 35$ and 1000 iterations for rejection sampling:
> mean score = 0.387, n=828 (failure rate: 343/1171)

In [None]:
# Mean baseline score over the sentences that produced one (NaNs are skipped),
# followed by how many sentences were scored and how many failed.
print(
    np.nanmean(scores),
    ", n={} (Failure rate: {}/{})".format(countobs - failurecount, failurecount, countobs),
)

For comparison, we can also get the score of the CPMI trees:

In [925]:
# The edge lists are stored in the CSV as strings; parse them back into Python
# objects. Iterating over the column directly avoids building a Series per row
# with iterrows() only to look the same cell up again.
gold_edges = [literal_eval(cell) for cell in EDGES_DF["gold_edges"]]
cpmi_edges = [literal_eval(cell) for cell in EDGES_DF["projective.edges.sum"]]

# Row i of the results is scored against observation i; the CPMI edges are
# passed as 0-indexed and shifted to 1-indexing by score_observation.
cpmi_scores = [score_observation(cpmi_edges[i], obs, convert_from_0index=True)
               for i, obs in enumerate(OBSERVATIONS)]
np.nanmean(cpmi_scores)

0.48217553684114794