In [13]:
import numpy as np
import codecs
import argparse
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from nltk import Tree
import random
import math
import subprocess
import sys
from collections import deque, Counter, defaultdict
from scipy.sparse.csgraph import minimum_spanning_tree
import string

from tools import dependency, sentence_attentions
from tools.dependency_converter import DependencyConverter
from tqdm import tqdm

import pandas as pd
import statistics

%matplotlib inline

In [4]:
# Switch between the development treebank (True) and the test treebank (False).
dev = True

# Path of the CoNLL-U file that the loading cells below read.
conllu_file_name = '../graph-extraction/endev.conllu' if dev else '../data/entest.conllu'

## Positional baseline

In [5]:
# Read the dependency relations from the CoNLL-U file selected above. The cells
# below consume the result as an iterable of per-sentence mappings
# (relation type -> index pairs).
# NOTE(review): the meaning of the positional `True` argument is defined in
# tools.dependency.read_conllu and is not visible here — confirm, and consider
# passing it by keyword.
dependency_rels = dependency.read_conllu(conllu_file_name,True)

In [6]:
def find_positional_baseline(dep_rels):
    """Learn the most frequent positional offset for every relation type.

    Args:
        dep_rels: iterable of per-sentence mappings from a relation type to a
            sequence of index pairs. Only the difference ``pair[1] - pair[0]``
            of each pair is used.

    Returns:
        A ``(mc_offset, positional_offsets)`` tuple where

        * ``mc_offset`` maps each relation type (in sorted key order) to its
          most common offset; when several offsets are equally frequent the
          one seen first wins.
        * ``positional_offsets`` is a ``defaultdict(list)`` mapping each
          relation type to every offset observed, in encounter order.
    """
    positional_offsets = defaultdict(list)

    for sent_rels in dep_rels:
        for k in sorted(sent_rels.keys()):
            # Append one by one so that a relation type with no pairs never
            # creates an (empty) entry in the defaultdict.
            for rel in sent_rels[k]:
                positional_offsets[k].append(rel[1] - rel[0])

    # Counter.most_common keeps first-encountered order among ties, so this
    # matches statistics.mode on Python >= 3.8 while also working on older
    # interpreters, where statistics.mode raises StatisticsError on ties.
    mc_offset = {
        k: Counter(positional_offsets[k]).most_common(1)[0][0]
        for k in sorted(positional_offsets.keys())
    }

    return mc_offset, positional_offsets

def uas_from_baseline(dep_rels, offset):
    """Print the per-relation-type UAS of a fixed-offset positional baseline.

    A pair counts as retrieved when ``pair[0] + offset[rel_type] == pair[1]``.

    Args:
        dep_rels: iterable of per-sentence mappings from a relation type to a
            sequence of index pairs.
        offset: mapping from relation type to the predicted offset. A relation
            type missing from this mapping has no prediction, so all of its
            pairs are counted as misses instead of raising ``KeyError``.

    Returns:
        None. One line is printed per relation type seen in ``dep_rels``, in
        sorted order, including types whose score is 0 and types that occur
        only with an empty pair list.
    """
    retrieved = defaultdict(int)
    total = defaultdict(int)
    for dep_rel in dep_rels:
        for rel_type, rel_pairs in dep_rel.items():
            # Register the type even when it has no pairs so the
            # "No relations" line below can actually be reached.
            total[rel_type] += 0
            predicted = offset.get(rel_type)
            for rel_pair in rel_pairs:
                total[rel_type] += 1
                if predicted is not None and rel_pair[0] + predicted == rel_pair[1]:
                    retrieved[rel_type] += 1
    # Iterate over every seen type (not only those with at least one hit) so
    # that relation types scoring 0 are reported rather than silently dropped.
    for k in sorted(total.keys()):
        if total[k] > 0:
            print(f"UAS for {k} : {retrieved[k]/total[k]} (number of relations: {total[k]})")
        else:
            print(f"No relations for {k}")

In [7]:
# Fit the positional baseline: the most common offset per relation type
# (pos_baseline) together with every raw offset observed (pos_data).
pos_baseline, pos_data = find_positional_baseline(dependency_rels)
# Bare last expression so the notebook displays the learned offsets.
pos_baseline

{'adj-clause-d2p': -2,
 'adj-clause-p2d': 2,
 'adj-modifier-d2p': 1,
 'adj-modifier-p2d': -1,
 'adv-clause-d2p': -2,
 'adv-clause-p2d': 2,
 'adv-modifier-d2p': 1,
 'adv-modifier-p2d': -1,
 'all-d2p': 1,
 'all-p2d': -1,
 'apposition-d2p': -3,
 'apposition-p2d': 3,
 'auxiliary-d2p': 1,
 'auxiliary-p2d': -1,
 'clausal subject-d2p': -3,
 'clausal subject-p2d': 3,
 'clausal-d2p': -4,
 'clausal-p2d': 4,
 'compound-d2p': 1,
 'compound-p2d': -1,
 'conjunct-d2p': -2,
 'conjunct-p2d': 2,
 'determiner-d2p': 1,
 'determiner-p2d': -1,
 'i object-d2p': -1,
 'i object-p2d': 1,
 'noun-modifier-d2p': -3,
 'noun-modifier-p2d': 3,
 'num-modifier-d2p': 1,
 'num-modifier-p2d': -1,
 'object-d2p': -2,
 'object-p2d': 2,
 'other-d2p': 1,
 'other-p2d': -1,
 'punctuation-d2p': 1,
 'punctuation-p2d': -1,
 'subject-d2p': 1,
 'subject-p2d': -1}

In [8]:
# Score the baseline and print UAS per relation type.
# NOTE(review): this evaluates on the same `dependency_rels` the offsets were
# fitted on, so the numbers are in-sample.
uas_from_baseline(dependency_rels, pos_baseline)

UAS for adj-clause-d2p : 0.3517915309446254 (number of relations: 307)
UAS for adj-clause-p2d : 0.3517915309446254 (number of relations: 307)
UAS for adj-modifier-d2p : 0.7627020785219399 (number of relations: 1732)
UAS for adj-modifier-p2d : 0.7627020785219399 (number of relations: 1732)
UAS for adv-clause-d2p : 0.11320754716981132 (number of relations: 424)
UAS for adv-clause-p2d : 0.11320754716981132 (number of relations: 424)
UAS for adv-modifier-d2p : 0.45729813664596275 (number of relations: 1288)
UAS for adv-modifier-p2d : 0.45729813664596275 (number of relations: 1288)
UAS for all-d2p : 0.325623460591133 (number of relations: 25984)
UAS for all-p2d : 0.325623460591133 (number of relations: 25984)
UAS for apposition-d2p : 0.17272727272727273 (number of relations: 110)
UAS for apposition-p2d : 0.17272727272727273 (number of relations: 110)
UAS for auxiliary-d2p : 0.6005121638924455 (number of relations: 781)
UAS for auxiliary-p2d : 0.6005121638924455 (number of relations: 781)
UA

## POS baseline

In [11]:
# NOTE(review): this rebinds `dependency_rels`, which the positional-baseline
# cells above populate with a different loader (read_conllu). Re-running those
# cells after this one will see the labeled/converted data instead.
dependency_rels = dependency.read_conllu_labeled(conllu_file_name)
# Convert every sentence with DependencyConverter; the counting loop in the next
# cell unpacks each item of a converted sentence as (dep, head, label, pos).
dependency_rels = [DependencyConverter(sent_rel).convert(return_root=True) for sent_rel in dependency_rels]

In [40]:
# Nested counter: dependency label -> (POS, head POS) pair -> number of occurrences.
dependency_pos_freq = defaultdict(lambda: defaultdict(int))
for sent_rels in dependency_rels:
    for dep, head, label, pos in sent_rels:
        # Entries labelled 'root' are not counted.
        if label == 'root':
            continue
        # Pair this entry's POS with field 3 of the entry found at index `head`.
        pos_pair = (pos, sent_rels[head][3])
        dependency_pos_freq[label][pos_pair] += 1

In [85]:
# Build a table with one column per dependency label and one row per POS pair
# (the tuple keys of the inner dicts become the row index).
frame = pd.DataFrame.from_dict(dependency_pos_freq)
# Drop POS pairs that were not observed with any label.
frame = frame.dropna(axis=0, how='all')
# Normalise every row by its own total so each row holds the distribution of
# labels for that POS pair. DataFrame.div(..., axis=0) aligns on the row index;
# the former `frame.sum(axis=1)[:, None]` relied on multi-dimensional Series
# indexing, which newer pandas versions no longer support.
frame = frame.div(frame.sum(axis=1), axis=0)
# Labels never seen with a POS pair get probability 0 instead of NaN.
frame = frame.fillna(0)
frame[['nsubj','obj', 'aux', 'amod', 'det']]

Unnamed: 0,Unnamed: 1,nsubj,obj,aux,amod,det
PRON,VERB,0.736395,0.140306,0.0,0.0,0.0
PROPN,VERB,0.369863,0.171233,0.0,0.0,0.0
NOUN,VERB,0.135484,0.417339,0.0,0.0,0.0
PRON,AUX,0.938776,0.006803,0.0,0.0,0.0
NOUN,AUX,0.438017,0.013774,0.0,0.0,0.0
...,...,...,...,...,...,...
INTJ,VERB,0.000000,0.000000,0.0,0.0,0.0
INTJ,NOUN,0.000000,0.000000,0.0,0.0,0.0
INTJ,ADJ,0.000000,0.000000,0.0,0.0,0.0
INTJ,PROPN,0.000000,0.000000,0.0,0.0,0.0


In [88]:
# Spot check of a single cell: select the 'nsubj' column, then the row whose
# two index levels are 'PRON' and 'VERB'.
frame['nsubj']['PRON']['VERB']

0.7363945578231292

In [59]:
# Display the whole table.
# NOTE(review): this cell's saved execution count (In [59]) is lower than those
# of the cells above it (In [85], In [88]), so its stored output may come from
# an earlier state of `frame` — restart the kernel and run all cells to refresh.
frame

Unnamed: 0,Unnamed: 1,nsubj,aux,amod,det
PRON,VERB,0.476086,0.0,0.0,0.0
PROPN,VERB,0.059373,0.0,0.0,0.0
NOUN,VERB,0.184717,0.0,0.0,0.0
PRON,AUX,0.151732,0.0,0.0,0.0
NOUN,AUX,0.087411,0.0,0.0,0.0
AUX,VERB,0.003848,0.901408,0.0,0.0
ADJ,VERB,0.003848,0.0,0.001155,0.0
PROPN,AUX,0.012644,0.0,0.0,0.0
PRON,NOUN,0.003848,0.0,0.0,0.000358
NUM,AUX,0.001649,0.0,0.0,0.0
