In [3]:
import random

import pandas as pd
import numpy as np
import networkx as nx

from tqdm.notebook import tqdm
from argparse import ArgumentParser
from collections import namedtuple
from ast import literal_eval

import main  # to use functions accessing data

# Columns of a CONLL-X line, in file order; each becomes one field of
# ObservationClass.
CONLL_COLS = ('index sentence lemma_sentence upos_sentence xpos_sentence '
              'morph head_indices governance_relations secondary_relations '
              'extra_info').split()
# NB: the 'index' field shadows tuple.index on instances, so
# `observation.index` is the token-index column, not the tuple method.
ObservationClass = namedtuple("Observation", CONLL_COLS)

In [4]:
# Data paths, relative to the notebook's working directory.
CONLL_DEV_PATH = "ptb3-wsj-data/ptb3-wsj-dev.conllx"
CPMI_SCORES_CSV = "results-azure/bert-large-cased_pad60_2020-04-09-13-57/scores_bert-large-cased_pad60_2020-04-09-13-57.csv"

# Gold dependency parses of the PTB WSJ dev set; presumably one
# ObservationClass instance per sentence (indexed per sentence below).
OBSERVATIONS = main.load_conll_dataset(CONLL_DEV_PATH, ObservationClass)
# Per-sentence results of a previous CPMI run; its "gold_edges" and
# "projective.edges.sum" columns are used for the comparison at the end.
EDGES_DF = pd.read_csv(CPMI_SCORES_CSV)

In [5]:
# Dependent tokens whose gold edges are excluded from scoring.
EXCLUDED_PUNCTUATION = ["", "'", "''", ",", ".", ";", "!", "?", ":", "``", "-LRB-", "-RRB-"]

def is_edge_to_ignore(edge, observation):
    """Return True if a gold edge should be left out of UUAS scoring.

    `edge` is a (head, dependent, ...) tuple of 1-based token indices, where
    head 0 is the artificial ROOT. The edge is ignored when its dependent
    token is in EXCLUDED_PUNCTUATION or its head is ROOT.
    """
    head, dependent = edge[0], edge[1]
    is_d_punct = observation.sentence[dependent - 1] in EXCLUDED_PUNCTUATION
    is_h_root = head == 0
    return is_d_punct or is_h_root

def score_observation(predicted_edges, observation, convert_from_0index=False):
    """Compute the unlabeled undirected attachment score (UUAS).

    Args:
        predicted_edges: iterable of (i, j) token-index pairs; direction is
            ignored.
        observation: a CONLL observation whose `head_indices`, `index` and
            `governance_relations` give the gold parse (1-based indices).
        convert_from_0index: if True, `predicted_edges` use 0-based indices
            and are shifted by +1 to match the gold indices.

    Returns:
        The fraction of scorable gold edges (see `is_edge_to_ignore`) that
        appear among the predicted edges, or NaN if there are none.
    """
    # Unordered (min, max) gold pairs, 1-indexed. Relations are zipped in
    # only so that the zip truncates exactly as before.
    gold_edges_set = {
        tuple(sorted((head, dep)))
        for head, dep, relation in zip(map(int, observation.head_indices),
                                       map(int, observation.index),
                                       observation.governance_relations)
        if not is_edge_to_ignore((head, dep, relation), observation)
    }

    # Shift predictions to 1-indexing when needed.
    offset = 1 if convert_from_0index else 0
    predicted_edges_set = {tuple(sorted((x[0] + offset, x[1] + offset)))
                           for x in predicted_edges}

    num_total = len(gold_edges_set)
    if num_total == 0:
        # np.nan, not np.NaN: the alias was removed in NumPy 2.0
        return np.nan
    num_correct = len(gold_edges_set & predicted_edges_set)
    return num_correct / float(num_total)

In [6]:
def Lk(ls):
    """Summarize a sequence of edge lengths.

    Args:
        ls: iterable of integer lengths (non-negative in practice).

    Returns:
        L: the lengths, sorted ascending.
        k: list where k[n] is how often length n occurs in L, for
           n = 0 .. max(L); empty when `ls` is empty.
    """
    from collections import Counter  # local import keeps the cell self-contained

    L = sorted(ls)
    if not L:
        # previously crashed on L[-1]
        return L, []
    counts = Counter(L)
    k = [counts[n] for n in range(L[-1] + 1)]
    return L, k

def lengthmatch(observation):
    '''Return the lengths and the set of scorable gold edges of `observation`.

    Gold edges are unordered (min, max) pairs of 1-based token indices,
    keeping only those that `is_edge_to_ignore` does not reject.

    Returns:
        lens: one length (max - min) per edge, in the set's iteration order.
        gold_edges_set: the set of gold (min, max) pairs.
    '''
    heads = map(int, observation.head_indices)
    deps = map(int, observation.index)
    gold_edges_set = set()
    for head, dep, relation in zip(heads, deps, observation.governance_relations):
        if is_edge_to_ignore((head, dep, relation), observation):
            continue
        gold_edges_set.add((min(head, dep), max(head, dep)))
    lens = [hi - lo for lo, hi in gold_edges_set]
    return lens, gold_edges_set

In [7]:
# obs = OBSERVATIONS[13]

# lens, ges = lengthmatch(obs)
# print(lens)
# print(ges)

# G=nx.Graph()
# G.add_edges_from(ges)
# # G.remove_edges_from(ges)
# print(sorted(G.edges()))
# bool((2,1) in G.edges())
# print("Is a tree?",nx.is_forest(G))
# score_observation(G.edges(),obs)

In [8]:
def generate_lengthmatch(observation, make_tree=True, iterations=100, rng=None):
    '''
    Score a random length-matched baseline for `observation`.

    Builds a random graph on the nodes of the scorable gold edges (see
    `lengthmatch`). The graph has exactly as many edges of each length |i - j|
    as the gold parse, following the algorithm described in
    https://cs.stackexchange.com/questions/116193. Its UUAS against the gold
    parse is then computed via `score_observation`.

    Args:
        observation: a CONLL observation (see `ObservationClass`).
        make_tree: if True, rejection-sample: discard any graph that is not
            a tree and retry, up to `iterations` attempts. If False, the
            first sampled graph is scored as-is.
        iterations: maximum number of sampling attempts.
        rng: optional `random.Random` instance for reproducible sampling;
            defaults to the module-level `random` generator.

    Returns:
        The UUAS of the sampled graph. Returns NaN if no tree was found within
        `iterations` attempts, or if the observation has no scorable gold edges.

    Raises:
        ValueError: if some length has fewer candidate node pairs than gold
            edges. This should not happen, because the gold edges are
            themselves candidates.
    '''
    if rng is None:
        rng = random
    lens, ges = lengthmatch(observation)
    print(f"len {len(observation[0])}", end=". ")
    if not ges:
        # An empty graph is not a tree (nx.is_tree raises on it), and
        # score_observation would return NaN anyway.
        print("FAILED (no scorable gold edges)")
        return np.nan

    # Loop invariants: the node set and the number of gold edges per length.
    nodes = {u for edge in ges for u in edge}
    L, k = Lk(lens)
    lengths = sorted(set(L))

    for iteration in range(iterations):
        T = nx.Graph()
        T.add_nodes_from(nodes)

        for l in lengths:
            # Candidate edges of length l: every pair of nodes l apart.
            # Each length is handled once, so T has no length-l edge yet.
            P = sorted((u, u + l) for u in nodes if u + l in nodes)
            if len(P) < k[l]:
                raise ValueError("ERROR: no possible edges")
            # random.sample needs a sequence; sets are rejected since Python 3.11.
            T.add_edges_from(rng.sample(P, k[l]))

            if make_tree and not nx.is_forest(T):
                break  # a cycle already exists; this attempt cannot be a tree

        if not make_tree or nx.is_tree(T):
            if make_tree:
                print(f"success after {iteration + 1} iterations")
            return score_observation(T.edges(), observation)
    print(f"FAILED (reached max iterations {iterations})")
    return np.nan

In [15]:
# Sentence-length range to sample in this run (inclusive; None = unbounded).
# The markdown below reports results for several such ranges.
MIN_SENT_LEN = 36
MAX_SENT_LEN = None
N_ITERATIONS = 15000  # rejection-sampling attempts per sentence

scores = []
countobs = 0
for i, obs in enumerate(tqdm(OBSERVATIONS)):
    print(str(i).ljust(4), end="")
    sent_len = len(obs[0])
    in_range = sent_len >= MIN_SENT_LEN and (MAX_SENT_LEN is None or sent_len <= MAX_SENT_LEN)
    if in_range:
        countobs += 1
        scores.append(generate_lengthmatch(obs, iterations=N_ITERATIONS))
    else:
        print(f"len {sent_len}", "(outside length range, skipped)")
# Use np.isnan rather than scores.count(np.NaN): list.count matches NaN only
# by object identity, and the np.NaN alias was removed in NumPy 2.0.
failurecount = int(np.isnan(scores).sum())

HBox(children=(FloatProgress(value=0.0, max=1700.0), HTML(value='')))

0   len 37. FAILED (reached max iterations 15000)
1   len 45. FAILED (reached max iterations 15000)
2   len 20 (too long)
3   len 37. success after 3994 iterations
4   len 23 (too long)
5   len 22 (too long)
6   len 17 (too long)
7   len 28 (too long)
8   len 34 (too long)
9   len 23 (too long)
10  len 30 (too long)
11  len 29 (too long)
12  len 23 (too long)
13  len 13 (too long)
14  len 20 (too long)
15  len 42. FAILED (reached max iterations 15000)
16  len 44. FAILED (reached max iterations 15000)
17  len 27 (too long)
18  len 26 (too long)
19  len 15 (too long)
20  len 9 (too long)
21  len 42. FAILED (reached max iterations 15000)
22  len 22 (too long)
23  len 17 (too long)
24  len 10 (too long)
25  len 22 (too long)
26  len 17 (too long)
27  len 14 (too long)
28  len 21 (too long)
29  len 18 (too long)
30  len 26 (too long)
31  len 8 (too long)
32  len 17 (too long)
33  len 27 (too long)
34  len 21 (too long)
35  len 32 (too long)
36  len 21 (too long)
37  len 9 (too long)
38  len

325 len 38. FAILED (reached max iterations 15000)
326 len 34 (too long)
327 len 27 (too long)
328 len 18 (too long)
329 len 8 (too long)
330 len 32 (too long)
331 len 18 (too long)
332 len 33 (too long)
333 len 36. success after 2473 iterations
334 len 39. FAILED (reached max iterations 15000)
335 len 4 (too long)
336 len 41. FAILED (reached max iterations 15000)
337 len 28 (too long)
338 len 41. success after 13848 iterations
339 len 20 (too long)
340 len 35 (too long)
341 len 18 (too long)
342 len 21 (too long)
343 len 14 (too long)
344 len 28 (too long)
345 len 20 (too long)
346 len 26 (too long)
347 len 27 (too long)
348 len 10 (too long)
349 len 35 (too long)
350 len 32 (too long)
351 len 27 (too long)
352 len 43. FAILED (reached max iterations 15000)
353 len 25 (too long)
354 len 26 (too long)
355 len 30 (too long)
356 len 23 (too long)
357 len 42. FAILED (reached max iterations 15000)
358 len 15 (too long)
359 len 31 (too long)
360 len 63. FAILED (reached max iterations 15000)
3

645 len 37. success after 1578 iterations
646 len 28 (too long)
647 len 32 (too long)
648 len 34 (too long)
649 len 51. FAILED (reached max iterations 15000)
650 len 25 (too long)
651 len 26 (too long)
652 len 17 (too long)
653 len 3 (too long)
654 len 41. FAILED (reached max iterations 15000)
655 len 32 (too long)
656 len 26 (too long)
657 len 43. FAILED (reached max iterations 15000)
658 len 21 (too long)
659 len 7 (too long)
660 len 21 (too long)
661 len 35 (too long)
662 len 31 (too long)
663 len 11 (too long)
664 len 31 (too long)
665 len 53. FAILED (reached max iterations 15000)
666 len 28 (too long)
667 len 24 (too long)
668 len 26 (too long)
669 len 13 (too long)
670 len 19 (too long)
671 len 24 (too long)
672 len 19 (too long)
673 len 31 (too long)
674 len 23 (too long)
675 len 26 (too long)
676 len 22 (too long)
677 len 27 (too long)
678 len 29 (too long)
679 len 19 (too long)
680 len 39. FAILED (reached max iterations 15000)
681 len 15 (too long)
682 len 22 (too long)
683 le

964 len 38. FAILED (reached max iterations 15000)
965 len 24 (too long)
966 len 20 (too long)
967 len 20 (too long)
968 len 8 (too long)
969 len 35 (too long)
970 len 37. success after 12310 iterations
971 len 24 (too long)
972 len 25 (too long)
973 len 14 (too long)
974 len 12 (too long)
975 len 2 (too long)
976 len 11 (too long)
977 len 25 (too long)
978 len 14 (too long)
979 len 26 (too long)
980 len 14 (too long)
981 len 30 (too long)
982 len 29 (too long)
983 len 5 (too long)
984 len 14 (too long)
985 len 15 (too long)
986 len 21 (too long)
987 len 7 (too long)
988 len 19 (too long)
989 len 2 (too long)
990 len 11 (too long)
991 len 6 (too long)
992 len 2 (too long)
993 len 12 (too long)
994 len 25 (too long)
995 len 6 (too long)
996 len 23 (too long)
997 len 39. FAILED (reached max iterations 15000)
998 len 2 (too long)
999 len 19 (too long)
1000len 33 (too long)
1001len 46. FAILED (reached max iterations 15000)
1002len 16 (too long)
1003len 12 (too long)
1004len 33 (too long)
10

1296len 36. FAILED (reached max iterations 15000)
1297len 8 (too long)
1298len 33 (too long)
1299len 27 (too long)
1300len 46. FAILED (reached max iterations 15000)
1301len 11 (too long)
1302len 31 (too long)
1303len 21 (too long)
1304len 25 (too long)
1305len 51. FAILED (reached max iterations 15000)
1306len 18 (too long)
1307len 32 (too long)
1308len 29 (too long)
1309len 22 (too long)
1310len 29 (too long)
1311len 26 (too long)
1312len 50. FAILED (reached max iterations 15000)
1313len 32 (too long)
1314len 21 (too long)
1315len 10 (too long)
1316len 23 (too long)
1317len 11 (too long)
1318len 32 (too long)
1319len 15 (too long)
1320len 50. FAILED (reached max iterations 15000)
1321len 25 (too long)
1322len 42. FAILED (reached max iterations 15000)
1323len 27 (too long)
1324len 37. FAILED (reached max iterations 15000)
1325len 8 (too long)
1326len 29 (too long)
1327len 34 (too long)
1328len 26 (too long)
1329len 18 (too long)
1330len 20 (too long)
1331len 20 (too long)
1332len 10 (to

1633len 46. FAILED (reached max iterations 15000)
1634len 30 (too long)
1635len 20 (too long)
1636len 15 (too long)
1637len 36. success after 1302 iterations
1638len 11 (too long)
1639len 20 (too long)
1640len 24 (too long)
1641len 21 (too long)
1642len 25 (too long)
1643len 22 (too long)
1644len 14 (too long)
1645len 17 (too long)
1646len 36. FAILED (reached max iterations 15000)
1647len 29 (too long)
1648len 27 (too long)
1649len 40. FAILED (reached max iterations 15000)
1650len 18 (too long)
1651len 24 (too long)
1652len 17 (too long)
1653len 32 (too long)
1654len 17 (too long)
1655len 40. FAILED (reached max iterations 15000)
1656len 16 (too long)
1657len 40. FAILED (reached max iterations 15000)
1658len 32 (too long)
1659len 25 (too long)
1660len 29 (too long)
1661len 27 (too long)
1662len 28 (too long)
1663len 18 (too long)
1664len 48. FAILED (reached max iterations 15000)
1665len 31 (too long)
1666len 28 (too long)
1667len 26 (too long)
1668len 33 (too long)
1669len 24 (too long

In [17]:
# Mean UUAS over successful samples, plus how many sentences failed.
mean_uuas = np.nanmean(scores)
n_success = countobs - failurecount
print(mean_uuas, f", n={n_success} (Failure rate: {failurecount}/{countobs})")

0.3548713256831176 , n=41 (Failure rate: 192/233)


With $14 \leq |s| \leq 35$ and 1000 rejection-sampling iterations:
> mean acc = 0.387, n=828 (Failure rate: 343/1171)


Now for 15000 iterations:

- with $14 \leq |s| \leq35$, mean acc = `0.372` , n=1116 (Failure rate: 55/1171)

- and with $2 < |s| \leq13$, mean acc = `0.525` , n=280 (Failure rate: 0/280)

- and with $35 < |s|$, mean acc = `0.355` , n=41 (Failure rate: 192/233)

- so overall mean acc = $(0.525 \cdot 280 + 0.372 \cdot 1116 + 0.355 \cdot 41)/(280 + 1116 + 41) =$ `0.401`, with a failure rate of $(55 + 192) / (280 + 1171 + 233) = 247/1684 \approx$ `0.147` (failures divided by all attempted sentences, not by successful ones)


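For comparison, here is the mean UUAS of the CPMI trees (column `projective.edges.sum`). Note that it is computed over all sentences, not only those for which the baseline above succeeded: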

In [925]:
# Parse the stringified edge lists stored in the CSV (one row per sentence).
gold_edges = EDGES_DF["gold_edges"].map(literal_eval).tolist()
cpmi_edges = EDGES_DF["projective.edges.sum"].map(literal_eval).tolist()
# Rows are paired with OBSERVATIONS by position, so the two must line up.
assert len(cpmi_edges) == len(OBSERVATIONS), "EDGES_DF and OBSERVATIONS are misaligned"

# CPMI edges are 0-indexed, the gold parses 1-indexed.
cpmi_scores = [score_observation(edges, obs, convert_from_0index=True)
               for edges, obs in zip(cpmi_edges, OBSERVATIONS)]
np.nanmean(cpmi_scores)

0.48217553684114794