In [22]:
class Stack(object):
    """Minimal LIFO stack with optional tracing of push/pop operations.

    Items may be of any type: the trace output and ``repr`` convert each
    item with ``str`` rather than assuming it is already a string.
    """

    def __init__(self, verbose=False):
        # Items are kept bottom-first; the last element is the top of stack.
        self.stack = []
        # When true, push() and pop() print the item and the resulting size.
        self.verbose = verbose

    def tos(self):
        """Return the top item without removing it, or None if the stack is empty."""
        if not self.stack:
            return None
        return self.stack[-1]

    def pop(self):
        """Remove and return the top item.

        Raises:
            AssertionError: if the stack is empty.
        """
        assert self.len() > 0, "Can't pop when stack is empty"
        item = self.stack.pop()
        if self.verbose:
            # Wrap in a 1-tuple so tuple items are printed instead of being
            # unpacked as multiple format arguments.
            print("POPPING: %s" % (item,))
            print("LEN:     %i" % len(self.stack))
        return item

    def push(self, item):
        """Place ``item`` on top of the stack."""
        self.stack.append(item)
        if self.verbose:
            print("PUSHING: %s" % (item,))
            print("LEN:     %i" % len(self.stack))

    def len(self):
        """Return the number of items currently on the stack."""
        return len(self.stack)

    def contains(self, item):
        """Return True if ``item`` is anywhere on the stack."""
        return item in self.stack

    def __repr__(self):
        # Bottom-to-top, "|"-separated; str() keeps this working for
        # non-string items (str.join alone would raise TypeError).
        return "|".join(map(str, self.stack))

In [23]:
# Label of the artificial root node; Parser.get_dependencies() drops any arc that touches it.
ROOT = "root"

def norm_arc(arc):
    """Return the endpoints of ``arc`` as a tuple in ascending order.

    Two arcs between the same pair of nodes therefore normalise to the same
    tuple, whichever direction they were written in.
    """
    endpoints = list(arc)
    endpoints.sort()
    return tuple(endpoints)

def norm_arcs(arcs):
    """Return the set of direction-independent (normalised) forms of ``arcs``."""
    return {norm_arc(arc) for arc in arcs}

class Parser(object):
    """Parser state: a working stack, the arcs built so far and an action log.

    The parser makes no decisions itself; each method applies one transition
    and appends a human-readable description of it to ``self.actions``.
    """

    def __init__(self, stack):
        # Working stack; must provide pop(), tos(), push() and contains().
        self.stack = stack
        # Arcs in creation order: (tos, buffer) from left_arc,
        # (buffer, tos) from right_arc.
        self.arcs = []
        # Normalised (sorted) form of every arc, used to reject an arc
        # between the same two nodes regardless of its direction.
        self.normed_arcs = set()
        # nodes with heads
        self.children = set()
        # One entry per transition applied, in order.
        self.actions = []

    def _record_arc(self, arc):
        """Store ``arc`` after checking that its normalised form is new."""
        n_arc = norm_arc(arc)
        # (n_arc,) keeps the message formattable: n_arc is itself a tuple.
        assert n_arc not in self.normed_arcs, "Arc already processed %s" % (n_arc,)
        self.arcs.append(arc)
        # Always store the normalised form so later lookups can match it.
        self.normed_arcs.add(n_arc)

    def get_dependencies(self):
        """Return the recorded arcs, excluding any arc that touches ROOT."""
        return [(l, r) for (l, r) in self.arcs if r != ROOT and l != ROOT]

    def left_arc(self, buffer):
        """Pop the top of stack and record the arc ``(tos, buffer)``.

        The popped node is marked as having a head.

        Raises:
            AssertionError: if the stack is empty or an arc between the two
                nodes was already recorded.
        """
        tos = self.stack.pop()
        self._record_arc((tos, buffer))
        self.children.add(tos)
        self.actions.append("L ARC   : " + tos + "->" + buffer)

    def right_arc(self, buffer):
        """Record the arc ``(buffer, tos)`` and push ``buffer`` onto the stack.

        ``buffer`` is marked as having a head; the top of stack is left in place.

        Raises:
            AssertionError: if an arc between the two nodes was already recorded.
        """
        tos = self.stack.tos()
        self._record_arc((buffer, tos))
        self.actions.append("R ARC   : " + tos + "<-" + buffer)
        self.children.add(buffer)
        self.stack.push(buffer)

    def reduce(self):
        """Pop the top of stack without creating an arc."""
        tos = self.stack.pop()
        self.actions.append("REDUCE  : Pop  %s" % tos)

    def shift(self, buffer):
        """Push ``buffer`` onto the stack without creating an arc."""
        self.stack.push(buffer)
        self.actions.append("SHIFT   : Push %s" % buffer)

    def has_head(self, item):
        """Return True if ``item`` was made a child by left_arc or right_arc."""
        return item in self.children

    def in_stack(self, item):
        """Return True if ``item`` is currently on the stack."""
        return self.stack.contains(item)

In [74]:
from collections import defaultdict

# Action labels passed to Oracle.cont(), which returns False for SHIFT and RARC
# and True for any other action.
SHIFT = "Shift"
REDUCE = "Reduce"
LARC = "LArc"
RARC = "Rarc"

class Oracle(object):
    """Picks and applies the next parser transition from the gold relations.

    ``crels`` is a collection of ordered pairs; ``mapping`` is the undirected
    adjacency view of the pairs that have not been turned into arcs yet.
    """

    def __init__(self, crels, parser):
        self.crels = set(crels)
        self.parser = parser
        # Build from the stored set so a one-shot iterable passed as ``crels``
        # is not iterated a second time after being exhausted.
        self.mapping = self.build_mappings(self.crels)

    def build_mappings(self, pairs):
        """Return a dict mapping each node to the set of nodes it is paired with."""
        mapping = defaultdict(set)
        for c, res in pairs:
            mapping[c].add(res)
            mapping[res].add(c)
        return mapping

    def cont(self, action):
        """Return False for SHIFT and RARC, True for every other action."""
        return action not in (SHIFT, RARC)

    def remove_relation(self, a, b):
        """Drop the a-b link from ``mapping``, deleting entries that become empty."""
        self.mapping[a].remove(b)
        if len(self.mapping[a]) == 0:
            del self.mapping[a]
        self.mapping[b].remove(a)
        if len(self.mapping[b]) == 0:
            del self.mapping[b]

    def consult(self, tos, buffer):
        """Apply one transition to the parser for the given stack top and buffer item.

        Returns True when the caller should consult again with the same
        buffer item (after a left arc or a reduce), and False once the buffer
        item has been consumed (after a shift or a right arc).
        """
        parser = self.parser
        if tos == ROOT:
            if buffer not in self.mapping:
                # map to root
                parser.right_arc(buffer)
                return self.cont(RARC)
            parser.shift(buffer)
            return self.cont(SHIFT)

        if (tos, buffer) in self.crels:
            # tos has no other pending relation, so it can leave the stack.
            # NOTE(review): self.mapping is a defaultdict, so this lookup
            # creates an empty entry when tos is absent - confirm intended.
            if len(self.mapping[tos]) == 1:
                # LEFT ARC (tos not root, and tos not have a head)
                assert not parser.has_head(tos), "%s already has a head #1" % tos
                parser.left_arc(buffer)
                self.remove_relation(tos, buffer)
                return self.cont(LARC)
            parser.right_arc(buffer)
            self.remove_relation(tos, buffer)
            return self.cont(RARC)

        if (buffer, tos) in self.crels:
            # if the buffer has multiple relations, and one or more is in the
            # stack, we need a left arc
            if len(self.mapping[buffer]) > 1 and any(
                parser.in_stack(item) for item in self.mapping[buffer]
            ):
                assert not parser.has_head(tos), "%s already has a head #2" % tos
                parser.left_arc(buffer)
                self.remove_relation(tos, buffer)
                return self.cont(LARC)
            parser.right_arc(buffer)
            self.remove_relation(tos, buffer)
            return self.cont(RARC)

        # No relation between tos and buffer.
        # Inspect the parser's own stack, not a module-level ``stack`` global,
        # so the oracle also works when the stack is a local of the caller.
        if tos not in self.mapping or parser.has_head(parser.stack.tos()):
            parser.reduce()
            return self.cont(REDUCE)
        parser.shift(buffer)
        return self.cont(SHIFT)

In [75]:
###################
#INPUTS
###################
codes = "ABCDE"
crels = {
    ("A","C"),
   # ("B","C"),
   # ("C","D"),
   # ("C","E")
}

###################

# Parser state starts with only the root node on the stack.
stack = Stack(False)
stack.push(ROOT)
parser = Parser(stack)
oracle = Oracle(crels, parser)

print("DEPS")
for crel in sorted(crels):
    print("\t" + str(crel))
print()

# Column width of the action text and length of the separator rules.
PAD = 20
LINE = PAD + len(ROOT) + 2 * len(codes) + 1

for buffer in codes:
    print("-" * LINE)
    print(buffer)
    print("-" * LINE)

    # Keep consulting the oracle until it consumes this buffer item.
    while True:
        tos = stack.tos()
        keep_going = oracle.consult(tos, buffer)
        print(parser.actions[-1].ljust(PAD) + " || STACK : " + str(stack))
        if not keep_going:
            break
        if stack.len() == 0:
            print("Empty stack, stopping")
            break

print()
print("*" * LINE)
print("Stack")
print("\t" + str(stack))
deps = parser.get_dependencies()
print("DEPS Actual")
for crel in sorted(crels):
    print("\t" + str(crel))
print("DEPS Pred")
for dep in sorted(deps):
    print("\t" + str(dep))
print("Actions")
for a in parser.actions:
    print("\t" + a)
print()
print("Ordered Match?    " + str(set(deps) == crels))

# Compare again with arc direction ignored.
ndeps = norm_arcs(deps)
ncrels = norm_arcs(crels)
diff = (ndeps - ncrels).union(ncrels - ndeps)
print("Un Ordered Match? " + str(not diff))
if diff:
    print(diff)

DEPS
	('A', 'C')

-----------------------------------
A
-----------------------------------
SHIFT   : Push A     || STACK : root|A
-----------------------------------
B
-----------------------------------
SHIFT   : Push B     || STACK : root|A|B
-----------------------------------
C
-----------------------------------
REDUCE  : Pop  B     || STACK : root|A
L ARC   : A->C       || STACK : root
R ARC   : root<-C    || STACK : root|C
-----------------------------------
D
-----------------------------------
REDUCE  : Pop  C     || STACK : root
R ARC   : root<-D    || STACK : root|D
-----------------------------------
E
-----------------------------------
REDUCE  : Pop  D     || STACK : root
R ARC   : root<-E    || STACK : root|E

***********************************
Stack
	root|E
DEPS Actual
	('A', 'C')
DEPS Pred
	('A', 'C')
Actions
	SHIFT   : Push A
	SHIFT   : Push B
	REDUCE  : Pop  B
	L ARC   : A->C
	R ARC   : root<-C
	REDUCE  : Pop  C
	R ARC   : root<-D
	REDUCE  : Pop  D
	R ARC   : root<

In [84]:
from pprint import pprint

def test_oracle(codes, crels, verbose=False):
    """Parse ``codes`` with a fresh Stack/Parser/Oracle and compare arcs to ``crels``.

    Each element of ``codes`` is fed to the oracle as one buffer item. With
    ``verbose`` set, the trace of actions and the final comparison are printed.

    Returns True when the predicted arcs equal ``crels`` once arc direction
    is ignored, False otherwise.
    """
    crels = set(crels)

    def emit(text=""):
        # Single switch for all trace output.
        if verbose:
            print(text)

    stack = Stack(False)
    stack.push(ROOT)
    parser = Parser(stack)
    oracle = Oracle(crels, parser)

    emit("DEPS")
    for crel in sorted(crels):
        emit("\t" + str(crel))
    emit()

    PAD = 20
    LINE = PAD + len(ROOT) + 2 * len(codes) + 1

    for buffer in list(codes):
        emit("-" * LINE)
        emit(buffer)
        emit("-" * LINE)

        # Keep consulting the oracle until it consumes this buffer item.
        while True:
            keep_going = oracle.consult(stack.tos(), buffer)
            emit(parser.actions[-1].ljust(PAD) + " || STACK : " + str(stack))
            if not keep_going:
                break
            if stack.len() == 0:
                emit("Empty stack, stopping")
                break

    emit()
    emit("*" * LINE)
    emit("Stack")
    emit("\t" + str(stack))
    deps = parser.get_dependencies()
    emit("DEPS Actual")
    for crel in sorted(crels):
        emit("\t" + str(crel))
    emit("DEPS Pred")
    for dep in sorted(deps):
        emit("\t" + str(dep))
    emit("Actions")
    for action in parser.actions:
        emit("\t" + action)
    emit()
    emit("Ordered Match?    " + str(set(deps) == crels))

    # Direction-insensitive comparison decides success.
    ndeps = norm_arcs(deps)
    ncrels = norm_arcs(crels)
    diff = (ndeps - ncrels).union(ncrels - ndeps)
    success = not diff
    emit("Un Ordered Match? " + str(success))
    if diff:
        emit(diff)
    return success

In [88]:
# Each entry is one set of gold relations to run over the codes "ABCDEF".
test_pairs = [
    [
        ("A","B"),
    ],
    [
        ("A","B"),
        ("B","C"),
    ],
    #C->B->A
    [
        ("C","B"),
        ("B","A"),
    ],
    [
        ("A","C"),
        ("B","C"),
    ],
    [
        ("A","B"),
        ("C","B"),
    ],
    [
        ("B","A"),
        ("B","C"),
    ],
    [
        ("A","C"),
        ("C","B"),
    ],
    # Hard - has to flip relation
    [
        ("A","D"),
        ("D","B"),
        ("B","C"),
    ],
    [
        ("D","A"),
        ("D","B"),
        ("B","C"),
    ],
    [
        ("D","A"),
        ("B","D"),
        ("B","C"),
    ],
    [
        ("A","E"),
        ("E","B"),
        ("B","D"),
        ("D","C"),
    ],
    [
        ("A","D"),
        ("D","B"),
        ("B","C"),
        ("A", "F"),
        ("A", "E"),
    ],
    [
        ("A","D"),
        ("D","B"),
        ("B","C"),
        ("A", "F"),
        ("E", "F"),
    ],
]

# Run quietly first; on failure, report the relations and replay with tracing.
for pairs in test_pairs:
    success = test_oracle("ABCDEF", pairs, verbose=False)
    if success:
        continue
    print("Error for relations:")
    pprint(pairs)
    success = test_oracle("ABCDEF", pairs, verbose=True)

In [86]:
class OracleSimpler(Oracle):
    """Oracle variant that overrides only ``consult``."""

    def consult(self, tos, buffer):
        """Apply one transition to the parser for the given stack top and buffer item.

        Returns True when the caller should consult again with the same
        buffer item (after a left arc or a reduce), and False once the buffer
        item has been consumed (after a shift or a right arc).
        """
        parser = self.parser

        if tos == ROOT:
            if buffer not in self.mapping:
                # map to root
                parser.right_arc(buffer)
                return self.cont(RARC)
            parser.shift(buffer)
            return self.cont(SHIFT)

        if (tos, buffer) in self.crels:
            # tos has no other pending relation, so it can leave the stack.
            if len(self.mapping[tos]) == 1:
                # LEFT ARC (tos not root, and tos not have a head)
                assert not parser.has_head(tos), "%s already has a head #1" % tos
                parser.left_arc(buffer)
                self.remove_relation(tos, buffer)
                return self.cont(LARC)
            parser.right_arc(buffer)
            self.remove_relation(tos, buffer)
            return self.cont(RARC)

        if (buffer, tos) in self.crels:
            # if the buffer has multiple relations, and one or more is in the
            # stack, we need a left arc
            if len(self.mapping[buffer]) > 1 and any(
                parser.in_stack(item) for item in self.mapping[buffer]
            ):
                assert not parser.has_head(tos), "%s already has a head #2" % tos
                parser.left_arc(buffer)
                self.remove_relation(tos, buffer)
                return self.cont(LARC)
            parser.right_arc(buffer)
            self.remove_relation(tos, buffer)
            return self.cont(RARC)

        # No relation between tos and buffer.
        # Inspect the parser's own stack, not a module-level ``stack`` global,
        # so the oracle also works when the stack is a local of the caller.
        if tos not in self.mapping or parser.has_head(parser.stack.tos()):
            parser.reduce()
            return self.cont(REDUCE)
        parser.shift(buffer)
        return self.cont(SHIFT)

In [87]:
# Smoke call of OracleSimpler. Uses the `crels` and `parser` globals, so the
# cell that defines them must have been run in this kernel first.
orc2 = OracleSimpler(crels, parser)
orc2.consult(tos="A", buffer="B")

NameError: name 'crels' is not defined