# Planning Notebook

In [104]:
import math
import unittest
import numpy as np
from itertools import product
import tqdm
from tqdm import tqdm_notebook

# import gtsam
import gtsam
from gtsam import *
from gtsam.utils.test_case import GtsamTestCase

# import gtbook
import gtbook
from gtbook.display import *
from gtbook.discrete import *

# import local package
import gtsam_example
from gtsam_example import SingleValueConstraint, MultiValueConstraint, NotSingleValueConstraint, OrConstraint, MutexConstraint, OperatorConstraint

# import parser
import SASParser
from SASParser import SAS, Operator

variables = Variables()
def pretty(obj):
    """Render ``obj`` with ``gtbook.display.pretty`` and the notebook-level ``variables`` registry."""
    formatter = gtbook.display.pretty
    return formatter(obj, variables)

import graphviz
class show(graphviz.Source):
    """Notebook helper that displays any object exposing a ``dot()`` method as a graph."""

    def __init__(self, obj):
        """Build the graphviz source from the string returned by ``obj.dot()``."""
        # graphviz.Source brings its own _repr_mimebundle_, so handing it the
        # dot text of the wrapped object is all that is needed for the
        # notebook to render it.
        dot_text = obj.dot()
        super().__init__(dot_text)

In [105]:
class SASToGTSAM():
    """Translate a parsed SAS planning task into GTSAM discrete keys and constraint factors.

    The converter caches the initial state, goal, variables, operators and
    mutex groups of the ``sas`` object and offers one ``generate_*`` method
    per kind of key or factor. Keys are created through the notebook-level
    ``variables`` registry.
    """

    def __init__(self, sas):
        """Cache the parts of ``sas`` that the generators read.

        Args:
            sas: parsed task exposing ``initial_state``, ``goal``,
                ``variables``, ``operators`` and ``mutex_group``.
        """
        self.sas = sas
        self.init = sas.initial_state
        self.goal = sas.goal
        self.vars = self.sas.variables
        self.ops = self.sas.operators
        self.mutex_groups = self.sas.mutex_group
        self.ops_names = [operator.name for operator in self.ops]

    def generate_string(self, cardinality, val):
        """Return a space-separated one-hot table with "1" at index ``val``.

        Example: ``generate_string(3, 2)`` returns ``"0 0 1"``.

        Raises:
            IndexError: if ``val`` is not a valid index into ``cardinality`` entries.
        """
        entries = ["0"] * cardinality
        entries[val] = "1"
        return ' '.join(entries)

    def generate_state(self, timestep):
        """Create one discrete key per SAS variable for ``timestep``.

        Keys are named ``"<var>_<timestep>"`` and returned in the iteration
        order of ``self.vars``; the factor generators look keys up by that
        position.
        """
        return [
            variables.discrete(str(var) + "_" + str(timestep), domain)
            for var, domain in self.vars.items()
        ]

    def generate_operator(self, timestep):
        """Create the discrete key ``"op_<timestep>"`` whose domain is the list of operator names."""
        return variables.discrete("op_" + str(timestep), self.ops_names)

    def generate_initial_factor(self, initial_state):
        """Return a MultiValueConstraint fixing ``initial_state`` to the task's initial values.

        Keys and values are paired by position.
        NOTE(review): this assumes ``self.init`` iterates in the same order as
        the keys in ``initial_state`` — confirm against the parser.
        """
        keys = gtsam.DiscreteKeys()
        for key in initial_state:
            keys.push_back(key)
        return MultiValueConstraint(keys, list(self.init.values()))

    def generate_goal_factor(self, goal_state):
        """Return a MultiValueConstraint fixing the goal variables in ``goal_state``.

        Only variables mentioned in ``self.goal`` are constrained. The goal is
        read from this instance rather than from a notebook-level converter.

        Raises:
            ValueError: if a goal variable is not one of ``self.vars``.
        """
        var_order = list(self.vars.keys())
        keys = gtsam.DiscreteKeys()
        vals = []
        for goal_var, goal_val in self.goal.items():
            keys.push_back(goal_state[var_order.index(goal_var)])
            vals.append(goal_val)
        return MultiValueConstraint(keys, vals)

    def generate_op_factor(self, state_t, state_tp, operator):
        """Return a MultiValueConstraint encoding one operator between two states.

        Preconditions constrain keys of ``state_t`` and effects constrain keys
        of ``state_tp``. The factor is built from the ``operator`` argument.

        Args:
            state_t: keys of the state before the operator is applied.
            state_tp: keys of the state after the operator is applied.
            operator: object exposing ``precondition`` and ``effect`` mappings
                from variable to value.

        Raises:
            ValueError: if a referenced variable is not one of ``self.vars``.
        """
        var_order = list(self.vars.keys())
        keys = gtsam.DiscreteKeys()
        vals = []
        for pre_var, pre_val in operator.precondition.items():
            keys.push_back(state_t[var_order.index(pre_var)])
            vals.append(pre_val)

        for eff_var, eff_val in operator.effect.items():
            keys.push_back(state_tp[var_order.index(eff_var)])
            vals.append(eff_val)

        # TODO: prevail conditions are not encoded yet.
        # NOTE(review): values of -1 are passed through unchanged — confirm
        # how MultiValueConstraint treats them.
        return MultiValueConstraint(keys, vals)

    def valid(self, values, mutex):
        """Return "1" if at most one entry of ``values`` equals its partner in ``mutex``, else "0".

        Raises:
            AssertionError: if the two sequences differ in length.
        """
        assert len(values) == len(mutex)
        matches = sum(1 for v, m in zip(values, mutex) if v == m)
        return "0" if matches > 1 else "1"

    def generate_mutex_factor(self, state_t):
        """Return one MutexConstraint per mutex group, expressed over the keys of ``state_t``.

        Each mutex group is an iterable of ``(variable, value)`` pairs; the
        constraint receives the matching keys and values in the same order.

        Raises:
            ValueError: if a referenced variable is not one of ``self.vars``.
        """
        var_order = list(self.vars.keys())
        factors = []
        for mutex_group in self.mutex_groups:
            keys = gtsam.DiscreteKeys()
            group_vals = []
            for var, val in mutex_group:
                keys.push_back(state_t[var_order.index(var)])
                group_vals.append(val)
            factors.append(MutexConstraint(keys, group_vals))
        return factors

In [106]:
# Parse the example SAS file and wrap it in a converter.
# `converter` is a notebook-level name that later cells read directly.
sas = SAS()
sas_dir = "sas/block_example.sas"
sas.read_file(sas_dir)
converter = SASToGTSAM(sas)

In [107]:
class TestSASToGTSAM(GtsamTestCase):
    """Tests for the SASToGTSAM converter on the example SAS task.

    Every test uses the converter built in ``setUp`` (``self.converter``), so
    the tests do not depend on any notebook-level converter.
    """

    def setUp(self):
        """Load the example task and create keys for timesteps 0 and 1."""
        sas = SAS()
        sas_dir = "sas/block_example.sas"
        sas.read_file(sas_dir)
        self.converter = SASToGTSAM(sas)
        self.init_state = self.converter.generate_state(0)
        self.op = self.converter.generate_operator(0)
        self.next_state = self.converter.generate_state(1)
        # Variable order shared by all position-based key lookups below.
        self.var_order = list(self.converter.vars.keys())

        # Every joint assignment of the state variables, one tuple per
        # assignment, ordered like the keys returned by generate_state.
        domains = [
            list(range(len(domain)))
            for domain in self.converter.vars.values()
        ]
        self.prods = list(product(*domains))

    def createVal(self, states, prod):
        """Return DiscreteValues assigning ``prod[i]`` to the key of ``states[i]``."""
        values = gtsam.DiscreteValues()
        for state, val in zip(states, prod):
            values[state[0]] = val
        return values

    def createOperatorVal(self, state1, state2, op):
        """Return DiscreteValues satisfying the preconditions and effects of ``op``.

        Preconditions are written on keys of ``state1`` and effects on keys of
        ``state2``; entries with value -1 are skipped.
        """
        values = gtsam.DiscreteValues()
        for var, val in op.precondition.items():
            if val == -1:
                continue
            values[state1[self.var_order.index(var)][0]] = val
        for var, val in op.effect.items():
            if val == -1:
                continue
            values[state2[self.var_order.index(var)][0]] = val
        return values

    def valid(self, prod, mutex_group):
        """Return 1.0 if at most one ``(var, val)`` pair of the group holds in ``prod``, else 0.0."""
        count = 0
        for var, val in mutex_group:
            if prod[var] == val:
                count += 1
            if count > 1:
                return 0.0
        return 1.0

    def test_generateState(self):
        """The example task has 9 state variables."""
        self.assertEqual(len(self.init_state), 9)

    def test_generateString(self):
        """generate_string yields a one-hot, space-separated table."""
        cardinality = 3
        string = self.converter.generate_string(cardinality, 2)
        self.assertEqual(string, "0 0 1")

    def test_generateOperatorState(self):
        """The operator key's cardinality equals the number of operators."""
        # there are 32 possible operators
        self.assertEqual(self.op[1], 32)

    def test_generateInitial(self):
        """Exactly one joint assignment satisfies the initial factor, and it matches the task."""
        initial_factor = self.converter.generate_initial_factor(self.init_state)
        total = 0.0
        for prod in self.prods:
            values = self.createVal(self.init_state, prod)
            output = initial_factor(values)
            if output == 1:
                for var, val in self.converter.init.items():
                    key = self.init_state[self.var_order.index(var)][0]
                    self.assertEqual(values[key], val)
            total += output
        self.assertEqual(total, 1)

    def test_generateGoal(self):
        """More than one assignment satisfies the goal factor, and each one matches the goal."""
        goal_factor = self.converter.generate_goal_factor(self.next_state)
        total = 0.0
        for prod in self.prods:
            values = self.createVal(self.next_state, prod)
            output = goal_factor(values)
            if output == 1:
                for var, val in self.converter.goal.items():
                    key = self.next_state[self.var_order.index(var)][0]
                    self.assertEqual(values[key], val)
            total += output
        self.assertGreater(total, 1)

    def test_generateOperatorFactor(self):
        """Each operator's factor accepts the assignment built from its own preconditions and effects."""
        for operator in self.converter.ops:
            op_factor = self.converter.generate_op_factor(self.init_state, self.next_state, operator)
            values = self.createOperatorVal(self.init_state, self.next_state, operator)
            self.assertEqual(op_factor(values), 1)

    # Disabled: exhaustive comparison of every mutex factor against `valid`.
    # def test_generateMutex(self):
    #     mutex_factors = self.converter.generate_mutex_factor(self.init_state)
    #     for prod in tqdm_notebook(self.prods, desc='possible states'):
    #         for mutex_factor, mutex_group in zip(mutex_factors, self.converter.mutex_groups):
    #             check_valid = self.valid(prod, mutex_group)
    #             values = self.createVal(self.init_state, prod)
    #             factor_valid = mutex_factor(values)
    #             tree_factor = mutex_factor.toDecisionTreeFactor()
    #             tree_valid = tree_factor(values)
    #             assert tree_valid == factor_valid == check_valid

In [108]:
# unittest.main(argv=[''], verbosity=2, exit=False)

In [109]:
# Keys for two consecutive timesteps and the operator variable between them.
state_t = converter.generate_state(0)
op_t = converter.generate_operator(0)
# The successor state belongs to timestep 1, so its keys are named "<var>_1"
# instead of reusing the "<var>_0" names already given to state_t.
state_tp = converter.generate_state(1)

In [110]:
state_t

[(0, 5), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 5), (7, 5), (8, 5)]

In [111]:
op_t

(9, 32)

In [112]:
state_tp

[(10, 5),
 (11, 2),
 (12, 2),
 (13, 2),
 (14, 2),
 (15, 2),
 (16, 5),
 (17, 5),
 (18, 5)]

In [115]:
# Pick the first operator for the exploratory cells that follow and display
# its precondition mapping (variable -> required value).
op = converter.ops[0]
op.precondition

{1: 0, 5: 0, 0: 4}

In [116]:
op.effect

{1: 1, 5: 1, 0: 0}

In [117]:
def generate_op_factor(state_t, state_tp, op):
    """Return a MultiValueConstraint for ``op`` between two states.

    Preconditions of ``op`` are placed on keys of ``state_t`` and its effects
    on keys of ``state_tp``. Keys are looked up by the position of the
    variable in the notebook-level ``converter.vars``.
    """
    var_order = list(converter.vars.keys())
    keys = gtsam.DiscreteKeys()
    vals = []
    # Preconditions first, then effects, so keys and values stay aligned.
    for keyed_state, assignment in ((state_t, op.precondition), (state_tp, op.effect)):
        for var, val in assignment.items():
            keys.push_back(keyed_state[var_order.index(var)])
            vals.append(val)
    return MultiValueConstraint(keys, vals)
    

In [118]:
op_f = generate_op_factor(state_t, state_tp, op)

In [119]:
op_f

MultiValueConstraint on 1 5 0 11 15 10 

In [None]:
def generate_op_factor(state_t, state_tp, operator):
    """Return a DecisionTreeFactor for ``operator`` as a product of one-hot unary factors.

    Preconditions constrain keys of ``state_t``, effects constrain keys of
    ``state_tp``, and prevail conditions constrain the same variable in both
    states. Entries with value -1 are skipped. Executing this cell rebinds
    the name ``generate_op_factor`` that an earlier cell also defines.

    Args:
        state_t: keys of the state before the operator is applied.
        state_tp: keys of the state after the operator is applied.
        operator: object exposing ``precondition``, ``effect`` and ``prevail``.

    Raises:
        ValueError: if a referenced variable is not in ``converter.vars``.
    """
    # TODO: precondition/effects should be multivalue constraints
    state = list(converter.vars.keys())
    preconditions = operator.precondition
    effects = operator.effect
    prevail = operator.prevail

    def one_hot(state_var, value):
        # state_var is a (key, cardinality) pair; the table string comes from
        # the converter because this function is not a method and has no self.
        cardinality = state_var[1]
        return gtsam.DecisionTreeFactor(state_var, converter.generate_string(cardinality, value))

    # NOTE(review): the product starts from a default-constructed factor —
    # confirm that multiplying into it behaves as a neutral element.
    f = gtsam.DecisionTreeFactor()
    for pre_var, pre_val in preconditions.items():
        if pre_val == -1:
            continue
        f *= one_hot(state_t[state.index(pre_var)], pre_val)
    for eff_var, eff_val in effects.items():
        if eff_val == -1:
            continue
        f *= one_hot(state_tp[state.index(eff_var)], eff_val)
    if prevail:
        for prev_var, prev_val in prevail.items():
            if prev_val == -1:
                continue
            idx = state.index(prev_var)
            # Each key supplies its own cardinality rather than reusing the
            # one left over from the loops above.
            f *= one_hot(state_t[idx], prev_val)
            f *= one_hot(state_tp[idx], prev_val)
    return f

In [83]:
# def plan(k):
#     states = []
#     operators = []
#     mutex_factors = []
#     op_factors = []
#     for i in range(k):
#         # generate state
#         state_t = converter.generate_state(i)
#         # generate mutex factor for the state
#         mutex_factor = converter.generate_mutex_factor(state_t)
#         mutex_factors.append(mutex_factor)
#         # generate binary key indicating if operator satisfies all preconditions and effects
#         operators_t = converter.generate_operator(i)
#         states.append(state_t)
#         operators.append(operators_t)
#     last_state = converter.generate_state(k)
#     mutex_factor = converter.generate_mutex_factor(last_state)
#     mutex_factors.append(mutex_factor)
#     states.append(last_state)

#     for j in range(len(states)-1):
#         op_group = []
#         for op_t, op in zip(operators[j], converter.ops):
#             op_factor = converter.generate_op_factor(states[j], states[j+1], op_t, op)
#             op_group.append(op_factor)
#         op_factor = OrConstraint(op_group)
#         op_factors.append(op_factor)

#     initial_factor = converter.generate_initial_factor()
#     goal_factor = converter.generate_goal_factor(states[-1])
#     return states, initial_factor,  goal_factor, mutex_factors, op_factors

In [84]:
# states, initial_factor,  goal_factor, mutex_factors, op_factors = plan(7)

In [85]:
# Unroll the problem for k steps: states 0..k, one operator key per step,
# and the mutex factors of every state.
k = 2
states = []
operators = []
mutex_factors = []
op_factors = []
for i in range(k):
    # discrete keys for every SAS variable at timestep i
    state_t = converter.generate_state(i)
    # mutex factors for that state (a list, one factor per mutex group)
    mutex_factor = converter.generate_mutex_factor(state_t)
    mutex_factors.append(mutex_factor)
    # discrete key over the operator names for timestep i
    operators_t = converter.generate_operator(i)
    states.append(state_t)
    operators.append(operators_t)
# the final state (timestep k) gets mutex factors but no operator key
last_state = converter.generate_state(k)
mutex_factor = converter.generate_mutex_factor(last_state)
mutex_factors.append(mutex_factor)
states.append(last_state)

In [86]:
# For each pair of consecutive states, build one candidate factor per operator
# and bundle them with that step's operator key in an OperatorConstraint.
for j in range(len(states)-1):
    op_group = []
    # NOTE: this loop rebinds the notebook-level name `op` on every iteration.
    for op in converter.ops:
        op_factor = converter.generate_op_factor(states[j], states[j+1], op)
        op_group.append(op_factor)
    op_factor = OperatorConstraint(operators[j], op_group)
    op_factors.append(op_factor)

# Constraints on the first and last state of the unrolled problem.
initial_factor = converter.generate_initial_factor(states[0])
goal_factor = converter.generate_goal_factor(states[-1])

In [87]:
# values = DiscreteValues()
# values[0] = 4
# values[1] = 0
# values[2] = 0
# values[3] = 0
# values[4] = 0
# values[5] = 0
# values[6] = 4
# values[7] = 4
# values[8] = 4
# #--------------
# values[9] = 1
# #--------------
# values[10] = 4
# values[11] = 0
# values[12] = 1
# values[13] = 0
# values[14] = 0
# values[15] = 1
# values[16] = 0
# values[17] = 4
# values[18] = 4

In [88]:
# operators[0]

In [89]:
# op0 = op_group[0]
# op1 = op_group[1]
# op2 = op_group[2]

In [90]:
# added = OperatorConstraint((9,3), [op0, op1, op2])

In [91]:
# added_tree = added.toDecisionTreeFactor()

In [92]:
# show(added_tree)

In [93]:
graph = gtsam.DiscreteFactorGraph()

In [94]:
# mutex_factors holds one list of factors per state, so add them one by one.
for m_factor in mutex_factors:
    for f in m_factor:
        graph.push_back(f)

In [95]:
# One OperatorConstraint per transition between consecutive states.
for op_factor in op_factors:
    graph.push_back(op_factor)

In [96]:
# graph.push_back(op_factors[1])
# graph.push_back(op_factors[43])
# graph.push_back(op_factors[66])
# graph.push_back(op_factors[111])
# graph.push_back(op_factors[131])
# graph.push_back(op_factors[179])
# Boundary conditions: goal on the last state, initial values on the first.
graph.push_back(goal_factor)
graph.push_back(initial_factor)

In [97]:
print(graph)


size: 19
factor 0: MutexConstraint on 1 0 6 7 8 
factor 1: MutexConstraint on 2 0 6 7 8 
factor 2: MutexConstraint on 3 0 6 7 8 
factor 3: MutexConstraint on 4 0 6 7 8 
factor 4: MutexConstraint on 5 0 6 7 8 
factor 5: MutexConstraint on 11 10 16 17 18 
factor 6: MutexConstraint on 12 10 16 17 18 
factor 7: MutexConstraint on 13 10 16 17 18 
factor 8: MutexConstraint on 14 10 16 17 18 
factor 9: MutexConstraint on 15 10 16 17 18 
factor 10: MutexConstraint on 21 20 26 27 28 
factor 11: MutexConstraint on 22 20 26 27 28 
factor 12: MutexConstraint on 23 20 26 27 28 
factor 13: MutexConstraint on 24 20 26 27 28 
factor 14: MutexConstraint on 25 20 26 27 28 
factor 15: OperatorConstraint on 9
factor 16: OperatorConstraint on 19
factor 17: MultiValueConstraint on 26 27 28 
factor 18: MultiValueConstraint on 0 1 2 3 4 5 6 7 8 



In [98]:
# NOTE(review): in the recorded run this call raised RuntimeError ("inference
# algorithm was called with inconsistent arguments"). The printed graph lists
# each OperatorConstraint on its operator key only; presumably the state keys
# used by its inner factors are not reported to the graph — TODO confirm.
val = graph.optimize()

RuntimeError: An inference algorithm was called with inconsistent arguments.  The
factor graph, ordering, or variable index were inconsistent with each
other, or a full elimination routine was called with an ordering that
does not include all of the variables.

In [51]:
print(graph)


size: 19
factor 0: MutexConstraint on 1 0 6 7 8 
factor 1: MutexConstraint on 2 0 6 7 8 
factor 2: MutexConstraint on 3 0 6 7 8 
factor 3: MutexConstraint on 4 0 6 7 8 
factor 4: MutexConstraint on 5 0 6 7 8 
factor 5: MutexConstraint on 11 10 16 17 18 
factor 6: MutexConstraint on 12 10 16 17 18 
factor 7: MutexConstraint on 13 10 16 17 18 
factor 8: MutexConstraint on 14 10 16 17 18 
factor 9: MutexConstraint on 15 10 16 17 18 
factor 10: MutexConstraint on 21 20 26 27 28 
factor 11: MutexConstraint on 22 20 26 27 28 
factor 12: MutexConstraint on 23 20 26 27 28 
factor 13: MutexConstraint on 24 20 26 27 28 
factor 14: MutexConstraint on 25 20 26 27 28 
factor 15: OperatorConstraint on 9
factor 16: OperatorConstraint on 19
factor 17: MultiValueConstraint on 26 27 28 
factor 18: MultiValueConstraint on 0 1 2 3 4 5 6 7 8 



In [None]:
# print(graph)

In [None]:
graph(val)

In [None]:
a = (0, 2)
b = (1, 2)
c = (2, 4)

In [None]:
# Truth tables over the binary keys a and b, one entry per joint assignment.
f_and = gtsam.DecisionTreeFactor([a, b], "0 0 0 1")
f_or = gtsam.DecisionTreeFactor([a, b], "0 1 1 1")
# Constant factors: "true" is 1 for every assignment and "false" is 0 for
# every assignment, so each table matches its name.
f_true = gtsam.DecisionTreeFactor([a, b], "1 1 1 1")
f_false = gtsam.DecisionTreeFactor([a, b], "0 0 0 0")

In [None]:
show(f_and)

In [None]:
add_constraint = OperatorConstraint(c, [f_and, f_or, f_true, f_false])

In [None]:
show(add_constraint.toDecisionTreeFactor())

In [None]:
or_constraint = OrConstraint([f_and, f_or])

In [None]:
show(or_constraint.toDecisionTreeFactor())

In [None]:
help(DecisionTreeFactor)

In [None]:
combined.toDecisionTreeFactor()

In [None]:
f_or.cardinality

In [None]:
f_and * f_or

In [None]:
for val in f_and.enumerate():
    print(val)

In [None]:
values = gtsam.DiscreteValues()
values[0] = 1
values[1] = 0
values[2] = 0

In [None]:
values.items

In [None]:
or_f = OrConstraint([f_and, f_or])

In [None]:
or_f_tree = or_f.toDecisionTreeFactor()
or_f_tree

In [None]:
or_f_tree * f_or

In [None]:
key_set = set()


In [None]:
# Enumerate every joint assignment of the keys a, b and c.
op_var = [a,b,c]
# NOTE(review): `input` shadows the builtin of the same name; a later cell
# reuses this name, so renaming it must be done in both places.
input = []
for var in op_var:
    # var is a (key, cardinality) pair; list the values 0..cardinality-1
    input.append(list(range(var[1])))
prods = list(product(*input))

In [None]:
prods

In [None]:
converter.mutex_groups

In [None]:
state_t = converter.generate_state(0)

In [None]:
state_t

In [None]:
state = list(converter.vars.keys())

In [None]:
state

In [None]:
prods = list(product(*input))

In [None]:
len(prods)

In [None]:
values = DiscreteValues()
values[]