In [1]:
from pprint import pprint

def test_oracle(codes, crels, orcl_fact, verbose=False):
    
    crels = set(crels)
    if verbose:
        prn_fun = lambda s="": print(s)
    else:
        prn_fun = lambda s="": None
    
    stack = Stack(False)
    stack.push(ROOT)
    parser = Parser(stack)
    oracle = orcl_fact(crels, parser)

    prn_fun("DEPS")
    for crel in sorted(crels):
        prn_fun("\t" + str(crel))
    prn_fun()

    PAD = 20
    LINE = PAD + len(ROOT) + 2 * len(codes) + 1

    for buffer in codes:
        prn_fun("-" * LINE)
        prn_fun(buffer)
        prn_fun("-" * LINE)

        while True:
            tos = stack.tos()
            if not oracle.consult(tos, buffer):
                prn_fun(parser.actions[-1].ljust(PAD) + " || STACK : " + str(stack))
                break

            prn_fun(parser.actions[-1].ljust(PAD) + " || STACK : " + str(stack))
            if stack.len() == 0:
                prn_fun("Empty stack, stopping")
                break

    prn_fun()
    prn_fun("*" * LINE)
    prn_fun("Stack")
    prn_fun("\t" + str(stack))
    deps = parser.get_dependencies()
    prn_fun("DEPS Actual")
    for crel in sorted(crels):
        prn_fun("\t" + str(crel))
    prn_fun("DEPS Pred")
    for dep in sorted(deps):
        prn_fun("\t" + str(dep))
    prn_fun("Actions")
    for a in parser.actions:
        prn_fun("\t" + a)
    prn_fun()
    prn_fun("Ordered Match?    " + str(set(deps) == crels))

    ndeps = norm_arcs(deps)
    ncrels = norm_arcs(crels)
    diff = (ndeps - ncrels).union(ncrels - ndeps)
    success = (len(diff) == 0)
    prn_fun("Un Ordered Match? " + str(success))
    if diff:
        prn_fun(diff)
    return success

In [2]:
class Stack(object):
    def __init__(self, verbose=False):    
        self.stack = []
        self.verbose = verbose
    
    def tos(self):
        if self.len() == 0:
            return None
        #assert self.len() > 0, "Can't peek when stack is empty"
        return self.stack[-1]
    
    def pop(self):
        assert self.len() > 0, "Can't pop when stack is empty"
        item = self.stack.pop()
        if self.verbose:
            print("POPPING: %s" % item)
            print("LEN:     %i" % len(self.stack))
        return item
    
    def push(self, item):
        self.stack.append(item)
        if self.verbose:
            print("PUSHING: %s" % item)
            print("LEN:     %i" % len(self.stack))
    
    def len(self):
        return len(self.stack)

    def contains(self, item):
        return item in self.stack
    
    def __repr__(self):
        return "|".join(self.stack)

In [3]:
ROOT = "root"

def norm_arc(arc):
    return tuple(sorted(arc))

def norm_arcs(arcs):
    return set(map(norm_arc, arcs))

class Parser(object):
    def __init__(self, stack):
        self.stack = stack
        self.arcs = []
        self.normed_arcs = set()
        # nodes with heads
        self.children = set()
        self.actions = []
        
    def get_dependencies(self):
        return [(l,r) for (l,r) in self.arcs if r != ROOT and l != ROOT]
        
    def left_arc(self, buffer):
        tos = self.stack.pop()
        #Pre-condition
        #assert self.has_head(tos) == False
        arc = (tos,buffer)
        n_arc = norm_arc(arc)
        assert n_arc not in self.normed_arcs, "Arc already processed %s" % (n_arc)
        self.arcs.append(arc)
        self.normed_arcs.add(arc)
        self.children.add(tos)
        self.actions.append("L ARC   : " + tos + "->" + buffer)
        
    def right_arc(self, buffer):
        tos = self.stack.tos()
        #normalize arc
        arc = (buffer,tos)
        n_arc = norm_arc(arc)
        assert n_arc not in self.normed_arcs, "Arc already processed %s" % (n_arc)
        self.arcs.append(arc)
        self.normed_arcs.add(n_arc)
        self.actions.append("R ARC   : " + tos + "<-" + buffer)
        self.children.add(buffer)
        self.stack.push(buffer)
        
    def reduce(self):
        tos = self.stack.pop()
        #assert self.has_head(tos) == True
        self.actions.append("REDUCE  : Pop  %s" % tos)
        
    def shift(self, buffer):
        self.stack.push(buffer)
        self.actions.append("SHIFT   : Push %s" % buffer)
    
    def skip(self, buffer):
        self.actions.append("SKIP    : item %s" % buffer)
    
    def has_head(self, item):
        return item in self.children
    
    def in_stack(self, item):
        return self.stack.contains(item)

In [4]:
from collections import defaultdict

SHIFT = "Shift"
REDUCE = "Reduce"
LARC = "LArc"
RARC = "Rarc"
SKIP = "Skip"

class Oracle(object):
    
    def __init__(self, crels, parser):
        self.parser = parser
        self.crels = norm_arcs(crels)
        self.mapping = self.build_mappings(crels)
    
    def build_mappings(self, pairs):
        mapping = defaultdict(set)
        for c,res in pairs:
            mapping[c].add(res)
            mapping[res].add(c)
        return mapping

    def cont(self, action):
        # continue parsing if REDUCE or LARC
        return action in (REDUCE,LARC)
    
    def remove_relation(self, a,b):
        self.mapping[a].remove(b)
        if len(self.mapping[a]) == 0:
            del self.mapping[a]
        self.mapping[b].remove(a)
        if len(self.mapping[b]) == 0:
            del self.mapping[b]
    
    def consult(self, tos, buffer):
        """
        Performs optimal decision for parser
        If true, continue processing, else Consume Buffer
        """
        parser = self.parser
        a,b = norm_arc((tos, buffer))
        if (a,b) in self.crels:
            # TOS has arcs remaining? If so, we need RARC, else LARC
            if len(self.mapping[tos]) == 1:
                parser.left_arc(buffer)
                self.remove_relation(tos, buffer)
                return self.cont(LARC)
            else:
                parser.right_arc(buffer)
                self.remove_relation(tos, buffer)
                return self.cont(RARC)
        else:
            if buffer not in self.mapping:
                parser.skip(buffer)
                return self.cont(SKIP)
            # If the buffer has relations further down in the stack, we need to POP the TOS
            for item in self.mapping[buffer]:
                if item == tos:
                    continue
                if parser.in_stack(item):
                    parser.reduce()
                    return self.cont(REDUCE)
            #end for
            #ELSE
            parser.shift(buffer)
            return self.cont(SHIFT)

In [5]:
test_pairs = []

test_pairs.append([
    ("A","B"),
])
test_pairs.append([
    ("A","B"),
    ("B","C"),
])
#C->B->A
test_pairs.append([
    ("C","B"),
    ("B","A"),
])
test_pairs.append([
    ("A","C"),
    ("B","C"),
])
test_pairs.append([
    ("A","B"),
    ("C","B"),
])
test_pairs.append([
    ("B","A"),
    ("B","C"),
])
test_pairs.append([
    ("A","C"),
    ("C","B"),
])

# Hard - has to flip relation
test_pairs.append([
    ("A","D"),
    ("D","B"),
    ("B","C"),
])
test_pairs.append([
    ("D","A"),
    ("D","B"),
    ("B","C"),
])
test_pairs.append([
    ("D","A"),
    ("B","D"),
    ("B","C"),
])

test_pairs.append([
    ("A","E"),
    ("E","B"),
    ("B","D"),
    ("D","C"),
])
test_pairs.append([
    ("A","D"),
    ("D","B"),
    ("B","C"),
    ("A", "F"),
    ("A", "E"),
])

test_pairs.append([
    ("A","D"),
    ("D","B"),
    ("B","C"),
    ("A", "F"),
    ("E", "F"),
])

oracle_fact = Oracle
for pairs in test_pairs:
    try:
        success = test_oracle("ABCDEF", pairs, oracle_fact, verbose=False)
    except:
        success = False
        
    if not success:
        print("Error for relations:")
        pprint(pairs)
        print()
        success = test_oracle("ABCDEF", pairs, oracle_fact, verbose=True)

## Visualize Parse for Tricker Graphs

### <span style="color:red">Doesn't Handle Cycles</span>
- So we remove the condition about only having a single parent

In [6]:
#[('1', '3'), ('1', '50'), ('3', '50')]
#['50', '1', '3']
pairs =[
    ("B","A"),
    ("B","C"),
    ("C","A"),
]
test_oracle("ABCDEF", pairs, Oracle, verbose=True)

DEPS
	('B', 'A')
	('B', 'C')
	('C', 'A')

-------------------------------------
A
-------------------------------------
SHIFT   : Push A     || STACK : root|A
-------------------------------------
B
-------------------------------------
R ARC   : A<-B       || STACK : root|A|B
-------------------------------------
C
-------------------------------------
L ARC   : B->C       || STACK : root|A
L ARC   : A->C       || STACK : root
SKIP    : item C     || STACK : root
-------------------------------------
D
-------------------------------------
SKIP    : item D     || STACK : root
-------------------------------------
E
-------------------------------------
SKIP    : item E     || STACK : root
-------------------------------------
F
-------------------------------------
SKIP    : item F     || STACK : root

*************************************
Stack
	root
DEPS Actual
	('B', 'A')
	('B', 'C')
	('C', 'A')
DEPS Pred
	('A', 'C')
	('B', 'A')
	('B', 'C')
Actions
	SHIFT   : Push A
	R ARC   : A<-B

True

In [7]:
pairs =[
    ("A","D"),
    ("D","B"),
    ("B","C"),
]
test_oracle("ABCDEF", pairs, Oracle, verbose=True)

DEPS
	('A', 'D')
	('B', 'C')
	('D', 'B')

-------------------------------------
A
-------------------------------------
SHIFT   : Push A     || STACK : root|A
-------------------------------------
B
-------------------------------------
SHIFT   : Push B     || STACK : root|A|B
-------------------------------------
C
-------------------------------------
R ARC   : B<-C       || STACK : root|A|B|C
-------------------------------------
D
-------------------------------------
REDUCE  : Pop  C     || STACK : root|A|B
L ARC   : B->D       || STACK : root|A
L ARC   : A->D       || STACK : root
SKIP    : item D     || STACK : root
-------------------------------------
E
-------------------------------------
SKIP    : item E     || STACK : root
-------------------------------------
F
-------------------------------------
SKIP    : item F     || STACK : root

*************************************
Stack
	root
DEPS Actual
	('A', 'D')
	('B', 'C')
	('D', 'B')
DEPS Pred
	('A', 'D')
	('B', 'D')
	('C',

True

## Non Projective Parse Should Fail Test

In [8]:
pairs =[
    ("A","C"),
    ("B","E"),
]
try:
    success = test_oracle("ABCDEF", pairs, Oracle, verbose=True)
except Exception as e:
    success = False
    raise e
assert success == False

DEPS
	('A', 'C')
	('B', 'E')

-------------------------------------
A
-------------------------------------
SHIFT   : Push A     || STACK : root|A
-------------------------------------
B
-------------------------------------
SHIFT   : Push B     || STACK : root|A|B
-------------------------------------
C
-------------------------------------
REDUCE  : Pop  B     || STACK : root|A
L ARC   : A->C       || STACK : root
SKIP    : item C     || STACK : root
-------------------------------------
D
-------------------------------------
SKIP    : item D     || STACK : root
-------------------------------------
E
-------------------------------------
SHIFT   : Push E     || STACK : root|E
-------------------------------------
F
-------------------------------------
SKIP    : item F     || STACK : root|E

*************************************
Stack
	root|E
DEPS Actual
	('A', 'C')
	('B', 'E')
DEPS Pred
	('A', 'C')
Actions
	SHIFT   : Push A
	SHIFT   : Push B
	REDUCE  : Pop  B
	L ARC   : A->C
	SKIP

## Test on Real Causal Relations (Limit to 2 or More Relations in a Sentence)

In [9]:
def normalize_cr(cr):
    return tuple(cr.replace("Causer:","").replace("Result:","").split("->"))

In [10]:
normalize_cr('Causer:14->Result:50')

('14', '50')

In [11]:
import pickle 

training_pickled = "/Users/simon.hughes/Google Drive/Phd/Data/CoralBleaching/Thesis_Dataset/training.pl"
with open(training_pickled, "rb+") as f:
    tagged_essays = pickle.load(f)
len(tagged_essays)

902

In [12]:
from collections import defaultdict

tag_freq = defaultdict(int)
unique_words = set()
for essay in tagged_essays:
    for sentence in essay.sentences:
        for word, tags in sentence:
            unique_words.add(word)
            for tag in tags:
                tag_freq[tag] += 1

EMPTY_TAG = "Empty"
#TODO - don't ignore Anaphor, other and rhetoricals here
cr_tags = list((t for t in tag_freq.keys() if ( "->" in t) and not "Anaphor" in t and not "other" in t and not "rhetorical" in t))

In [13]:
from pprint import pprint

relations = []
unq_cr_tags = set(cr_tags)
skipped_sent = 0
skipped_crels = 0
num_sents = 0
for essay_ix, essay in enumerate(tagged_essays):
    for sent_ix, taggged_sentence in enumerate(essay.sentences):
        num_sents += 1
        tag_seq = []
        un_tags = set()
        crel2tags = defaultdict(set)
        def add_tag(tag, crel):
            if tag not in un_tags:
                tag_seq.append(tag)
                un_tags.add(tag)
            crel2tags[crel].add(tag)
        
        has_causal = False
        for i, (wd,tags) in enumerate(taggged_sentence):
            csl = unq_cr_tags.intersection(tags)
            if not csl:
                continue
            has_causal = True
            for crel in csl:
                l_causer, r_effect = crel.split("->")
                l,r = normalize_cr(crel)
                if l_causer in tags:
                    add_tag(l, crel)
                if r_effect in tags:
                    add_tag(r, crel)                
                if l in tags:
                    add_tag(l, crel)
                if r in tags:
                    add_tag(r, crel)
        
        # Don't count sentences without any relations as skipped
        if not has_causal:
            continue
        
        supported_causal = set()
        supported_codes = set()
        for crel, tags in crel2tags.items():
            if len(tags) < 2:
                skipped_crels += 1
                continue
            supported_causal.add(crel)
            supported_codes.update(tags)
        
        if not supported_causal:
            skipped_sent += 1
            continue
        # filter out any tags that were only part of unsupported causal relations
        tag_seq = [tag for tag in tag_seq if tag in supported_codes]
        relations.append((supported_causal,tag_seq))
        
num_sents, len(relations), skipped_sent, skipped_crels

(8292, 2217, 86, 141)

## 4 Errors Below Look Are from Non-Projective Parses
**NOTES**
With only 4 errors as 4 missed relations, hardly worth worrying about. 
One solution would be to train a forward and a backward parser, parse the sentence in both directions and merge the deps. In each case that would pick up all deps.

In [14]:
errors = 0
exs = []
for supported_causal, tag_seq in relations[:]:
    supported_causal = sorted(supported_causal)
    crels = [normalize_cr(crel) for crel in supported_causal]

    try:
        success = test_oracle(tag_seq, crels, Oracle, verbose=False)
    except Exception as e:
        exs.append(e)
        success = False
        
    if not success:
        errors += 1
        print("Error for relations:")
        pprint(crels)
        pprint(tag_seq)
        #print()
        #success = test_oracle(tag_seq, crels, Oracle, verbose=True)
        #break

Error for relations:
[('1', '50'), ('1', '7'), ('3', '50'), ('3', '7')]
['1', '3', '50', '7']
Error for relations:
[('1', '4'), ('1', '50'), ('3', '4'), ('3', '50')]
['1', '3', '4', '50']
Error for relations:
[('13', '6'), ('13', '7'), ('3', '6'), ('3', '7'), ('7', '50')]
['6', '3', '13', '7', '50']
Error for relations:
[('12', '11'), ('13', '11'), ('13', '14')]
['12', '13', '11', '14']


## <span style="color:red">NEED to determine if all errors are non-projective<span>

### Define a New Oracle That Can be Consulted Without Acting

In [15]:
from collections import defaultdict

SHIFT = "Shift"
REDUCE = "Reduce"
LARC = "LArc"
RARC = "Rarc"
SKIP = "Skip"

class Oracle2(object):
    
    def __init__(self, crels, parser):
        self.parser = parser
        self.crels = norm_arcs(crels)
        self.mapping = self.build_mappings(crels)
    
    def build_mappings(self, pairs):
        mapping = defaultdict(set)
        for c,res in pairs:
            mapping[c].add(res)
            mapping[res].add(c)
        return mapping

    def should_continue(self, action):
        # continue parsing if REDUCE or LARC
        return action in (REDUCE,LARC)
    
    def remove_relation(self, a,b):
        self.mapping[a].remove(b)
        if len(self.mapping[a]) == 0:
            del self.mapping[a]
        self.mapping[b].remove(a)
        if len(self.mapping[b]) == 0:
            del self.mapping[b]
    
    def consult(self, tos, buffer):
        """
        Performs optimal decision for parser
        If true, continue processing, else Consume Buffer
        """
        parser = self.parser
        a,b = norm_arc((tos, buffer))
        if (a,b) in self.crels:
            # TOS has arcs remaining? If so, we need RARC, else LARC
            if len(self.mapping[tos]) == 1:
                return LARC
            else:
                return RARC
        else:
            if buffer not in self.mapping:
                return SKIP
            # If the buffer has relations further down in the stack, we need to POP the TOS
            for item in self.mapping[buffer]:
                if item == tos:
                    continue
                if parser.in_stack(item):
                    return REDUCE
            #end for
            #ELSE
            return SHIFT
        
    def execute(self, action, tos, buffer):
        """
        Performs optimal decision for parser
        If true, continue processing, else Consume Buffer
        """
        parser = self.parser
        if action == LARC:
            parser.left_arc(buffer)
            self.remove_relation(tos, buffer)
        elif action == RARC:
            parser.right_arc(buffer)
            self.remove_relation(tos, buffer)
        elif action == REDUCE:
            parser.reduce()
        elif action == SHIFT:
            parser.shift(buffer)
        elif action == SKIP:
            pass
        else:
            raise Exception("Unknown parsing action %s" % action)
        return self.should_continue(action)

In [16]:
from pprint import pprint

def test_oracle2(codes, crels, orcl_fact, verbose=False):
    
    crels = set(crels)
    if verbose:
        prn_fun = lambda s="": print(s)
    else:
        prn_fun = lambda s="": None
    
    stack = Stack(False)
    stack.push(ROOT)
    parser = Parser(stack)
    oracle = orcl_fact(crels, parser)

    prn_fun("DEPS")
    for crel in sorted(crels):
        prn_fun("\t" + str(crel))
    prn_fun()

    PAD = 20
    LINE = PAD + len(ROOT) + 2 * len(codes) + 1

    for buffer in codes:
        prn_fun("-" * LINE)
        prn_fun(buffer)
        prn_fun("-" * LINE)

        while True:
            tos = stack.tos()
            action = oracle.consult(tos, buffer)
            if not oracle.execute(action, tos, buffer):
                prn_fun(parser.actions[-1].ljust(PAD) + " || STACK : " + str(stack))
                break

            prn_fun(parser.actions[-1].ljust(PAD) + " || STACK : " + str(stack))
            if stack.len() == 0:
                prn_fun("Empty stack, stopping")
                break

    prn_fun()
    prn_fun("*" * LINE)
    prn_fun("Stack")
    prn_fun("\t" + str(stack))
    deps = parser.get_dependencies()
    prn_fun("DEPS Actual")
    for crel in sorted(crels):
        prn_fun("\t" + str(crel))
    prn_fun("DEPS Pred")
    for dep in sorted(deps):
        prn_fun("\t" + str(dep))
    prn_fun("Actions")
    for a in parser.actions:
        prn_fun("\t" + a)
    prn_fun()
    prn_fun("Ordered Match?    " + str(set(deps) == crels))

    ndeps = norm_arcs(deps)
    ncrels = norm_arcs(crels)
    diff = (ndeps - ncrels).union(ncrels - ndeps)
    success = (len(diff) == 0)
    prn_fun("Un Ordered Match? " + str(success))
    if diff:
        prn_fun(diff)
    return success

In [17]:
test_oracle2(['5', '50'], [('5', '50')], Oracle2, verbose=True)

DEPS
	('5', '50')

-----------------------------
5
-----------------------------
SHIFT   : Push 5     || STACK : root|5
-----------------------------
50
-----------------------------
L ARC   : 5->50      || STACK : root
L ARC   : 5->50      || STACK : root

*****************************
Stack
	root
DEPS Actual
	('5', '50')
DEPS Pred
	('5', '50')
Actions
	SHIFT   : Push 5
	L ARC   : 5->50

Ordered Match?    True
Un Ordered Match? True


True

In [18]:
errors = 0
exs = []
for supported_causal, tag_seq in relations[:]:
    supported_causal = sorted(supported_causal)
    crels = [normalize_cr(crel) for crel in supported_causal]

    try:
        success = test_oracle2(tag_seq, crels, Oracle2, verbose=False)
    except Exception as e:
        exs.append(e)
        success = False
        
    if not success:
        errors += 1
        print("Error for relations:")
        pprint(crels)
        pprint(tag_seq)
        #print()
        #success = test_oracle(tag_seq, crels, Oracle, verbose=True)
        #break

Error for relations:
[('1', '50'), ('1', '7'), ('3', '50'), ('3', '7')]
['1', '3', '50', '7']
Error for relations:
[('1', '4'), ('1', '50'), ('3', '4'), ('3', '50')]
['1', '3', '4', '50']
Error for relations:
[('13', '6'), ('13', '7'), ('3', '6'), ('3', '7'), ('7', '50')]
['6', '3', '13', '7', '50']
Error for relations:
[('12', '11'), ('13', '11'), ('13', '14')]
['12', '13', '11', '14']
