# Planning Notebook

In [17]:
import math
import unittest
import numpy as np
from itertools import product

# import gtsam
import gtsam
from gtsam import *
from gtsam.utils.test_case import GtsamTestCase

# import gtbook
import gtbook
from gtbook.display import *
from gtbook.discrete import *

# import local package
import gtsam_example
from gtsam_example import SingleValueConstraint, MultiValueConstraint, NotSingleValueConstraint, OrConstraint, MutexConstraint, OperatorConstraint

# import parser
import SASParser
from SASParser import SAS, Operator

variables = Variables()
def pretty(obj): 
    return gtbook.display.pretty(obj, variables)

import graphviz
class show(graphviz.Source):
    """ Display an object with a dot method as a graph."""

    def __init__(self, obj):
        """Construct from object with 'dot' method."""
        # This small class takes an object, calls its dot function, and uses the
        # resulting string to initialize a graphviz.Source instance. This in turn
        # has a _repr_mimebundle_ method, which then renders it in the notebook.
        super().__init__(obj.dot())

In [18]:
class SASToGTSAM():
    def __init__(self, sas):
        self.sas = sas
        self.init = sas.initial_state
        self.goal = sas.goal
        self.vars = self.sas.variables
        self.ops = self.sas.operators
        self.mutex_groups = self.sas.mutex_group
        self.ops_names = []
        for op in self.ops:
            self.ops_names.append(op.name)

    def generate_string(self, cardinality, val):
        string_list = []
        for i in range(cardinality):
            string_list.append("0")
        string_list[val] = "1"
        string = ' '.join(string_list)
        return string
    
    def generate_state(self, timestep):
        state = []
        for var, val in self.vars.items():
            state_var = variables.discrete(str(var)+"_"+str(timestep), val)
            state.append(state_var)
        return state
    
    def generate_operator(self, timestep):
        op_var = variables.discrete("op_"+str(timestep), self.ops_names)
        return op_var
    
    def generate_initial_factor(self, initial_state):
        keys = gtsam.DiscreteKeys()
        for key in initial_state:
            keys.push_back(key)
        init_values = list(self.init.values())
        init_f = MultiValueConstraint(keys, init_values)
        return init_f

    def generate_goal_factor(self, goal_state):
        state = list(self.vars.keys())
        keys = gtsam.DiscreteKeys()
        vals = []
        for goal_var, goal_val in converter.goal.items():
            keys.push_back(goal_state[state.index(goal_var)])
            vals.append(goal_val)
        goal_f = MultiValueConstraint(keys, vals)
        return goal_f
    
    def generate_op_factor(self, state_t, state_tp, operator):
        state = list(self.vars.keys())
        preconditions = operator.precondition
        effects = operator.effect
        prevail = operator.prevail
        f = gtsam.DecisionTreeFactor()
        for pre_var, pre_val in preconditions.items():
            if pre_val == -1:
                continue
            state_var = state_t[state.index(pre_var)]
            cardinality = state_var[1]
            state_f = gtsam.DecisionTreeFactor(state_var, self.generate_string(cardinality, pre_val))
            f *= state_f
        for eff_var, eff_val in effects.items():
            if eff_val == -1:
                continue
            state_var = state_tp[state.index(eff_var)]
            cardinality = state_var[1]
            state_f = gtsam.DecisionTreeFactor(state_var, self.generate_string(cardinality, eff_val))
            f *= state_f
        if prevail:
            for prev_var, prev_val in prevail.items():
                if prev_val == -1:
                    continue
                state_var_t = state_t[state.index(prev_var)]
                state_var_tp = state_tp[state.index(prev_var)]
                cardinality = state_var[1]
                state_t_f = gtsam.DecisionTreeFactor(state_var_t, self.generate_string(cardinality, prev_val))
                state_tp_f = gtsam.DecisionTreeFactor(state_var_tp, self.generate_string(cardinality, prev_val))
                f *= state_t_f
                f *= state_tp_f
        return f

    def valid(self, values, mutex):
        assert len(values) == len(mutex)
        count = 0
        for v, m in zip(values, mutex):
            if v == m:
                count += 1
            if count > 1:
                return "0"
        return "1"

    def generate_mutex_factor(self, state_t):
        state = list(self.vars.keys())
        mutex_variables = []
        mutex_values = []
        
        for mutex_group in self.mutex_groups:
            var_group = []
            val_group = []
            for var, val in mutex_group:
                state_var = state_t[state.index(var)]
                var_group.append(state_var)
                val_group.append(val)
            mutex_variables.append(var_group)
            mutex_values.append(val_group)
        
        factors = []
        for mutex_var, mutex_val in zip(mutex_variables, mutex_values):
            keys = gtsam.DiscreteKeys()
            for var in mutex_var:
                keys.push_back(var)
            mutex = MutexConstraint(keys, mutex_val)
            factors.append(mutex)
        return factors
        
        # f = gtsam.DecisionTreeFactor()
        # for factor in factors:
        #     f = f * factor.toDecisionTreeFactor()
        # for mutex_var, mutex_val in zip(mutex_variables, mutex_values):
        #     input = []
        #     for var in mutex_var:
        #         input.append(list(range(var[1])))
        #     prods = list(product(*input))
        #     mutex_string_list = []
        #     for prod in prods:
        #         is_valid = self.valid(prod, mutex_val)
        #         mutex_string_list.append(is_valid)
        #     mutex_string = ' '.join(mutex_string_list)
        #     f = f*gtsam.DecisionTreeFactor(mutex_var, mutex_string)
        # return f

In [19]:
sas = SAS()
sas_dir = "sas/block_example.sas"
sas.read_file(sas_dir)
converter = SASToGTSAM(sas)

In [14]:
class TestSASToGTSAM(GtsamTestCase):
    """Tests for Single Value Constraints"""

    def setUp(self):
        sas = SAS()
        sas_dir = "sas/block_example.sas"
        sas.read_file(sas_dir)
        self.converter = SASToGTSAM(sas)
        self.init_state = self.converter.generate_state(0)
        self.op = self.converter.generate_operator(0)
        self.next_state = self.converter.generate_state(1)

        input = []
        for _, vars in self.converter.vars.items():
            input.append(list(range(len(vars))))
        # tried this but this crashes the kernel
        # for _, vars in self.converter.vars.items():
        #     input.append(list(range(len(vars))))
        self.prods = list(product(*input))
    
    def createVal(self, states, prod):
        values = gtsam.DiscreteValues()
        for state, val in zip(states, prod):
            values[state[0]] = val
        return values

    def createOperatorVal(self, state1, state2, op):
        state = list(converter.vars.keys())
        values = gtsam.DiscreteValues()
        for var, val in op.precondition.items():
            if val == -1:
                continue
            values[state1[state.index(var)][0]] = val
        for var, val in op.effect.items():
            if val == -1:
                continue
            values[state2[state.index(var)][0]] = val
        return values
    
    def valid(self, prod, mutex_group):
        count = 0
        for var, val in mutex_group:
            if prod[var] == val:
                count += 1
            if count > 1:
                return 0.0
        return 1.0

    def test_generateState(self):
        assert len(self.init_state) == 9
    
    def test_generateString(self):
        cardinality = 3
        string = self.converter.generate_string(cardinality, 2)
        assert string == "0 0 1"

    def test_generateOperator(self):
        # there are 32 possible operators
        assert self.op[1] == 32

    def test_generateInitial(self):
        state = list(converter.vars.keys())
        initial_factor = self.converter.generate_initial_factor(self.init_state)
        total = 0.0
        for prod in self.prods:
            values = self.createVal(self.init_state, prod)
            output = initial_factor(values)
            if output == 1:
                for var, val in self.converter.init.items():
                    assert values[self.init_state[state.index(var)][0]] == val
            total += output
        assert total == 1

    def test_generateGoal(self):
        state = list(converter.vars.keys())
        goal_factor = self.converter.generate_goal_factor(self.next_state)
        total = 0.0
        for prod in self.prods:
            values = self.createVal(self.next_state, prod)
            output = goal_factor(values)
            if output == 1:
                for var, val in self.converter.goal.items():
                    assert values[self.next_state[state.index(var)][0]] == val
            total += output
        assert total > 1

    def test_generateOperator(self):
        for operator in converter.ops:
            op_factor = self.converter.generate_op_factor(self.init_state, self.next_state, operator)
            values = self.createOperatorVal(self.init_state, self.next_state, operator)
            assert op_factor(values) == 1
    
    def test_generateMutex(self):
        mutex_factors = self.converter.generate_mutex_factor(self.init_state)
        for prod in self.prods:
            for mutex_factor, mutex_group in zip(mutex_factors, converter.mutex_groups):
                check_valid = self.valid(prod, mutex_group)
                values = self.createVal(self.init_state, prod)
                factor_valid = mutex_factor(values)
                tree_factor = mutex_factor.toDecisionTreeFactor()
                tree_valid = tree_factor(values)
                if tree_valid != factor_valid:
                    print(tree_valid, check_valid, prod)
                    return ""
            # values = self.createVal(self.init_state, prod)
            # if mutex_factor(values) != is_valid:
            #     print(is_valid, prod)
            #     print(values)

In [15]:
converter.mutex_groups

[[[1, 0], [0, 0], [6, 1], [7, 1], [8, 1]],
 [[2, 0], [0, 1], [6, 0], [7, 2], [8, 2]],
 [[3, 0], [0, 2], [6, 2], [7, 0], [8, 3]],
 [[4, 0], [0, 3], [6, 3], [7, 3], [8, 0]],
 [[5, 0], [0, 0], [6, 0], [7, 0], [8, 0]]]

In [16]:
unittest.main(argv=[''], verbosity=2, exit=False)

test_generateGoal (__main__.TestSASToGTSAM) ... ok
test_generateInitial (__main__.TestSASToGTSAM) ... ok
test_generateMutex (__main__.TestSASToGTSAM) ... ok
test_generateOperator (__main__.TestSASToGTSAM) ... ok
test_generateState (__main__.TestSASToGTSAM) ... ok
test_generateString (__main__.TestSASToGTSAM) ... 

0.0 1.0 (0, 0, 0, 0, 0, 0, 0, 0, 1)


ok

----------------------------------------------------------------------
Ran 6 tests in 1.472s

OK


<unittest.main.TestProgram at 0x7f07c3b3efd0>

In [20]:
values = DiscreteValues()
values[0] = 0
values[1] = 0
values[2] = 0
values[3] = 0
values[4] = 0
values[5] = 0
values[6] = 0
values[7] = 0
values[8] = 1

In [21]:
state_t = converter.generate_state(0)

In [22]:
mutex_factors = converter.generate_mutex_factor(state_t)

In [23]:
for f in mutex_factors:
    print(f(values))
    

0.0
0.0
0.0
1.0
0.0


In [24]:
converter.mutex_groups

[[[1, 0], [0, 0], [6, 1], [7, 1], [8, 1]],
 [[2, 0], [0, 1], [6, 0], [7, 2], [8, 2]],
 [[3, 0], [0, 2], [6, 2], [7, 0], [8, 3]],
 [[4, 0], [0, 3], [6, 3], [7, 3], [8, 0]],
 [[5, 0], [0, 0], [6, 0], [7, 0], [8, 0]]]

In [25]:
mutex_factors[3]

MutexConstraint on 4 0 6 7 8 

In [26]:
for f in mutex_factors:
    tree = f.toDecisionTreeFactor()
    print(tree(values))

0.0
0.0
0.0
0.0
0.0


In [43]:
tree = mutex_factors[3].toDecisionTreeFactor()

In [44]:
values

DiscreteValues{0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 0, 6: 0, 7: 0, 8: 1}

In [45]:
tree(values)

0.0

In [46]:
tree

8,7,6,4,0,value
0,0,0,0,0,0
0,0,0,0,1,0
0,0,0,0,2,0
0,0,0,0,3,0
0,0,0,0,4,0
0,0,0,1,0,1
0,0,0,1,1,1
0,0,0,1,2,1
0,0,0,1,3,0
0,0,0,1,4,1


In [None]:
# def plan(k):
#     states = []
#     operators = []
#     mutex_factors = []
#     op_factors = []
#     for i in range(k):
#         # generate state
#         state_t = converter.generate_state(i)
#         # generate mutex factor for the state
#         mutex_factor = converter.generate_mutex_factor(state_t)
#         mutex_factors.append(mutex_factor)
#         # generate binary key indicating if operator satisfies all preconditions and effects
#         operators_t = converter.generate_operator(i)
#         states.append(state_t)
#         operators.append(operators_t)
#     last_state = converter.generate_state(k)
#     mutex_factor = converter.generate_mutex_factor(last_state)
#     mutex_factors.append(mutex_factor)
#     states.append(last_state)

#     for j in range(len(states)-1):
#         op_group = []
#         for op_t, op in zip(operators[j], converter.ops):
#             op_factor = converter.generate_op_factor(states[j], states[j+1], op_t, op)
#             op_group.append(op_factor)
#         op_factor = OrConstraint(op_group)
#         op_factors.append(op_factor)

#     initial_factor = converter.generate_initial_factor()
#     goal_factor = converter.generate_goal_factor(states[-1])
#     return states, initial_factor,  goal_factor, mutex_factors, op_factors

In [None]:
# states, initial_factor,  goal_factor, mutex_factors, op_factors = plan(7)

In [None]:
k = 2
states = []
operators = []
mutex_factors = []
op_factors = []
for i in range(k):
    # generate state
    state_t = converter.generate_state(i)
    # generate mutex factor for the state
    mutex_factor = converter.generate_mutex_factor(state_t)
    mutex_factors.append(mutex_factor)
    # generate binary key indicating if operator satisfies all preconditions and effects
    operators_t = converter.generate_operator(i)
    states.append(state_t)
    operators.append(operators_t)
last_state = converter.generate_state(k)
mutex_factor = converter.generate_mutex_factor(last_state)
mutex_factors.append(mutex_factor)
states.append(last_state)

# for j in range(len(states)-1):
#     op_group = []
#     for op_t, op in zip(operators[j], converter.ops):
#         op_factor = converter.generate_op_factor(states[j], states[j+1], op)
#         op_group.append(op_factor)
#     op_factor = OperatorConstraint(operators[j], op_group)
#     op_factors.append(op_factor)

# initial_factor = converter.generate_initial_factor()
# goal_factor = converter.generate_goal_factor(states[-1])

In [None]:
for j in range(len(states)-1):
    op_group = []
    for op in converter.ops:
        op_factor = converter.generate_op_factor(states[j], states[j+1], op)
        op_group.append(op_factor)
    op_factor = OperatorConstraint(operators[j], op_group)
    op_factors.append(op_factor)

initial_factor = converter.generate_initial_factor(states[0])
goal_factor = converter.generate_goal_factor(states[-1])

In [None]:
values = DiscreteValues()
values[0] = 4
values[1] = 0
values[2] = 0
values[3] = 0
values[4] = 0
values[5] = 0
values[6] = 4
values[7] = 4
values[8] = 4
#--------------
values[9] = 1
#--------------
values[10] = 4
values[11] = 0
values[12] = 1
values[13] = 0
values[14] = 0
values[15] = 1
values[16] = 0
values[17] = 4
values[18] = 4

In [None]:
# operators[0]

In [None]:
# op0 = op_group[0]
# op1 = op_group[1]
# op2 = op_group[2]

In [None]:
# added = OperatorConstraint((9,3), [op0, op1, op2])

In [None]:
# added_tree = added.toDecisionTreeFactor()

In [None]:
# show(added_tree)

In [None]:
graph = gtsam.DiscreteFactorGraph()

In [None]:
for m_factor in mutex_factors:
    graph.push_back(m_factor)

In [None]:
for op_factor in op_factors:
    graph.push_back(op_factor)

In [None]:
# graph.push_back(op_factors[1])
# graph.push_back(op_factors[43])
# graph.push_back(op_factors[66])
# graph.push_back(op_factors[111])
# graph.push_back(op_factors[131])
# graph.push_back(op_factors[179])
graph.push_back(goal_factor)
graph.push_back(initial_factor)

In [None]:
val = graph.optimize()

In [None]:
# print(graph)

In [None]:
graph(val)

In [None]:
a = (0, 2)
b = (1, 2)
c = (2, 4)

In [None]:
f_and = gtsam.DecisionTreeFactor([a, b], "0 0 0 1")
f_or = gtsam.DecisionTreeFactor([a, b], "0 1 1 1")
f_true = gtsam.DecisionTreeFactor([a, b], "0 0 0 0")
f_false = gtsam.DecisionTreeFactor([a, b], "1 1 1 1")

In [None]:
show(f_and)

In [None]:
add_constraint = OperatorConstraint(c, [f_and, f_or, f_true, f_false])

In [None]:
show(add_constraint.toDecisionTreeFactor())

In [None]:
or_constraint = OrConstraint([f_and, f_or])

In [None]:
show(or_constraint.toDecisionTreeFactor())

In [None]:
help(DecisionTreeFactor)

In [None]:
combined.toDecisionTreeFactor()

In [None]:
f_or.cardinality

In [None]:
f_and * f_or

In [None]:
for val in f_and.enumerate():
    print(val)

In [None]:
values = gtsam.DiscreteValues()
values[0] = 1
values[1] = 0
values[2] = 0

In [None]:
values.items

In [None]:
or_f = OrConstraint([f_and, f_or])

In [None]:
or_f_tree = or_f.toDecisionTreeFactor()
or_f_tree

In [None]:
or_f_tree * f_or

In [None]:
key_set = set()


In [None]:
op_var = [a,b,c]
input = []
for var in op_var:
    input.append(list(range(var[1])))
prods = list(product(*input))

In [None]:
prods

In [None]:
converter.mutex_groups

In [None]:
state_t = converter.generate_state(0)

In [None]:
state_t

In [None]:
state = list(converter.vars.keys())

In [None]:
state

In [None]:
prods = list(product(*input))

In [None]:
len(prods)

In [None]:
values = DiscreteValues()
values[]