# Planning Notebook

In [1]:
import math
import unittest
import numpy as np
from itertools import product

# import gtsam
import gtsam
from gtsam import *
from gtsam.utils.test_case import GtsamTestCase

# import gtbook
import gtbook
from gtbook.display import *
from gtbook.discrete import *

# import local package
import gtsam_example
from gtsam_example import SingleValueConstraint, MultiValueConstraint, NotSingleValueConstraint, OrConstraint

# import parser
import SASParser
from SASParser import SAS, Operator

variables = Variables()
def pretty(obj): 
    return gtbook.display.pretty(obj, variables)

import graphviz
class show(graphviz.Source):
    """ Display an object with a dot method as a graph."""

    def __init__(self, obj):
        """Construct from object with 'dot' method."""
        # This small class takes an object, calls its dot function, and uses the
        # resulting string to initialize a graphviz.Source instance. This in turn
        # has a _repr_mimebundle_ method, which then renders it in the notebook.
        super().__init__(obj.dot())

In [2]:
class TestOrConstraint(GtsamTestCase):
    """Tests for Single Value Constraints"""

    def setUp(self):
        self.keys = DiscreteKeys()
        key_list = [(0, 2), (1, 2), (2, 2)]
        for key in key_list:
            self.keys.push_back(key)
        f_and = DecisionTreeFactor([key_list[0], key_list[1]], "0 0 0 1")
        f_or = DecisionTreeFactor([key_list[1], key_list[2]], "0 1 1 1")
        self.constraint = OrConstraint([f_and, f_or])

    def test_operator(self):
        values = DiscreteValues()
        values[self.keys.at(0)[0]] = 1
        values[self.keys.at(1)[0]] = 0
        values[self.keys.at(2)[0]] = 1
        self.assertEqual(self.constraint(values), 1.0)
    
    def test_toDecisionTree(self):
        expected = self.constraint.toDecisionTreeFactor()
        self.assertIsInstance(expected, DecisionTreeFactor)
        self.gtsamAssertEquals(DecisionTreeFactor(self.keys, "0 0 1 1 1 1 1 1"), expected)


In [3]:
keys = DiscreteKeys()
key_list = [(0, 2), (1, 2), (2, 2)]
for key in key_list:
    keys.push_back(key)
f_and = DecisionTreeFactor([key_list[0], key_list[1]], "0 0 0 1")
f_or = DecisionTreeFactor([key_list[1], key_list[2]], "0 1 1 1")
constraint = OrConstraint([f_and, f_or])

In [12]:
values = DiscreteValues()
values[keys.at(0)[0]] = 1
values[keys.at(1)[0]] = 0
values[keys.at(2)[0]] = 1

In [13]:
f_and

0,1,value
0,0,0
0,1,0
1,0,0
1,1,1


In [14]:
f_or

1,2,value
0,0,0
0,1,1
1,0,1
1,1,1


In [15]:
f_and * f_or

0,1,2,value
0,0,0,0
0,0,1,0
0,1,0,0
0,1,1,0
1,0,0,0
1,0,1,0
1,1,0,1
1,1,1,1


In [16]:
tree_constraint = constraint.toDecisionTreeFactor()
tree_constraint

2,1,0,value
0,0,0,0
0,0,1,0
0,1,0,1
0,1,1,1
1,0,0,1
1,0,1,1
1,1,0,1
1,1,1,1


In [20]:
f_and(values)

0.0

In [17]:
f_or(values)

1.0

In [18]:
constraint(values)

1.0

In [19]:
tree_constraint(values)

1.0

In [3]:
unittest.main(argv=[''], verbosity=2, exit=False)

In [42]:
class SASToGTSAM():
    def __init__(self, sas):
        self.sas = sas
        self.init = sas.initial_state
        self.goal = sas.goal
        self.vars = self.sas.variables
        self.ops = self.sas.operators
        self.mutex_groups = self.sas.mutex_group
        self.ops_names = []
        for op in self.ops:
            self.ops_names.append(op.name)

    def generate_string(self, cardinality, val):
        string_list = []
        for i in range(cardinality):
            string_list.append("0")
        string_list[val] = "1"
        string = ' '.join(string_list)
        return string
    
    def generate_state(self, timestep):
        state = []
        for var, val in self.vars.items():
            state_var = variables.discrete(str(var)+"_"+str(timestep), val)
            state.append(state_var)
        return state
    
    def generate_operator(self, timestep):
        operators = []
        for op in self.ops:
            op_var = variables.binary(op.name+"_"+str(timestep))
            operators.append(op_var)
        return operators
    
    def generate_initial_factor(self):
        initial_state = self.generate_state(0)
        keys = gtsam.DiscreteKeys()
        for key in initial_state:
            keys.push_back(key)
        init_values = list(self.init.values())
        multi_f = MultiValueConstraint(keys, init_values)
        return multi_f

    def generate_goal_factor(self, goal_state):
        state = list(self.vars.keys())
        goal_f = gtsam.DecisionTreeFactor()
        for goal_var, goal_val in self.goal.items():
            state_var = goal_state[state.index(goal_var)]
            single_f = SingleValueConstraint(state_var, goal_val)
            f = single_f.toDecisionTreeFactor()
            goal_f *= f
        return goal_f

    def generate_op_factor(self, state_t, state_tp, op_t, operator):
        state = list(self.vars.keys())
        preconditions = operator.precondition
        effects = operator.effect
        prevail = operator.prevail
        f = gtsam.DecisionTreeFactor(op_t, self.generate_string(2, 1))
        for pre_var, pre_val in preconditions.items():
            state_var = state_t[state.index(pre_var)]
            cardinality = state_var[1]
            state_f = gtsam.DecisionTreeFactor(state_var, self.generate_string(cardinality, pre_val))
            f *= state_f
        for eff_var, eff_val in effects.items():
            state_var = state_tp[state.index(eff_var)]
            cardinality = state_var[1]
            state_f = gtsam.DecisionTreeFactor(state_var, self.generate_string(cardinality, eff_val))
            f *= state_f
        if prevail:
            for prev_var, prev_val in prevail.items():
                state_var_t = state_t[state.index(prev_var)]
                state_var_tp = state_tp[state.index(prev_var)]
                cardinality = state_var[1]
                state_t_f = gtsam.DecisionTreeFactor(state_var_t, self.generate_string(cardinality, prev_val))
                state_tp_f = gtsam.DecisionTreeFactor(state_var_tp, self.generate_string(cardinality, prev_val))
                f *= state_t_f
                f *= state_tp_f
        return f

    def valid(self, values, mutex):
        assert len(values) == len(mutex)
        count = 0
        for v, m in zip(values, mutex):
            if v == m:
                count += 1
            if count > 1:
                return "0"
        return "1"

    def generate_mutex_factor(self, state_t):
        state = list(self.vars.keys())
        mutex_variables = []
        mutex_values = []
        
        for mutex_group in self.mutex_groups:
            var_group = []
            val_group = []
            for var, val in mutex_group:
                state_var = state_t[state.index(var)]
                var_group.append(state_var)
                val_group.append(val)
            mutex_variables.append(var_group)
            mutex_values.append(val_group)
        
        f = gtsam.DecisionTreeFactor()
        for mutex_var, mutex_val in zip(mutex_variables, mutex_values):
            input = []
            for var in mutex_var:
                input.append(list(range(var[1])))
            prods = list(product(*input))
            mutex_string_list = []
            for prod in prods:
                is_valid = self.valid(prod, mutex_val)
                mutex_string_list.append(is_valid)
            mutex_string = ' '.join(mutex_string_list)
            f = f*gtsam.DecisionTreeFactor(mutex_var, mutex_string)
        return f

In [43]:
sas = SAS()
sas_dir = "sas/block_example.sas"
sas.read_file(sas_dir)
converter = SASToGTSAM(sas)

In [44]:
names = converter.ops_names # 1, 43, 66, 111, 131, 179

In [45]:
# def plan(k):
#     if k < 1:
#         return "There should be at least one state"
#     states = []
#     operators = []
#     mutex_factors = []
#     op_factors = []
#     for i in range(k-1):
#         state_t = converter.generate_state(i)
#         operators_t = converter.generate_operator(i)
#         states.append(state_t)
#         operators.append(operators_t)
#     last_state = converter.generate_state(k-1)
#     states.append(last_state)
#     for j in range(len(states)-1):
#         mutex_factor = converter.generate_mutex_factor(states[j])
#         mutex_factors.append(mutex_factor)
#         # op_group = []
#         for op_t, op in zip(operators[j], converter.ops):
#             op_factor = converter.generate_op_factor(states[j], states[j+1], op_t, op)
#             # op_group.append(op_factor)
#             op_factors.append(op_factor)
#         # op_factors.append(op_group)
#     last_mutex_factor = converter.generate_mutex_factor(states[-1])
#     mutex_factors.append(last_mutex_factor)
#     initial_factor = converter.generate_initial_factor()
#     goal_factor = converter.generate_goal_factor(states[-1])
#     return states, initial_factor,  goal_factor, mutex_factors, op_factors

In [46]:
# states, initial_factor,  goal_factor, mutex_factors, op_factors = plan(7)

In [148]:
# graph.push_back(op_factors[1])
# graph.push_back(op_factors[43])
# graph.push_back(op_factors[66])
# graph.push_back(op_factors[111])
# graph.push_back(op_factors[131])
# graph.push_back(op_factors[179])
# graph.push_back(goal_factor)
# graph.push_back(initial_factor)

In [149]:
a = (0, 2)
b = (1, 2)
c = (2, 2)
d = (3, 2)

In [157]:
f_and = DecisionTreeFactor([a, b], "0 0 0 1")
f_bet = DecisionTreeFactor([b, c], "0 0 0 0")
f_or = DecisionTreeFactor([b, c], "0 1 1 1")

In [158]:
f_and

0,1,value
0,0,0
0,1,0
1,0,0
1,1,1


In [159]:
f_or

1,2,value
0,0,0
0,1,1
1,0,1
1,1,1


In [160]:
f_and * f_or

0,1,2,value
0,0,0,0
0,0,1,0
0,1,0,0
0,1,1,0
1,0,0,0
1,0,1,0
1,1,0,1
1,1,1,1


In [154]:
for val in f_and.enumerate():
    print(val)

(DiscreteValues{0: 0, 1: 0}, 0.0)
(DiscreteValues{0: 0, 1: 1}, 0.0)
(DiscreteValues{0: 1, 1: 0}, 0.0)
(DiscreteValues{0: 1, 1: 1}, 1.0)


In [None]:
f_and.

In [None]:
key_set = set()


In [49]:
op_var = [a,b,c]
input = []
for var in op_var:
    input.append(list(range(var[1])))
prods = list(product(*input))

In [50]:
prods

[(0, 0, 0),
 (0, 0, 1),
 (0, 1, 0),
 (0, 1, 1),
 (1, 0, 0),
 (1, 0, 1),
 (1, 1, 0),
 (1, 1, 1)]

In [91]:
converter.mutex_groups

[[[1, 0], [0, 0], [6, 1], [7, 1], [8, 1]],
 [[2, 0], [0, 1], [6, 0], [7, 2], [8, 2]],
 [[3, 0], [0, 2], [6, 2], [7, 0], [8, 3]],
 [[4, 0], [0, 3], [6, 3], [7, 3], [8, 0]],
 [[5, 0], [0, 0], [6, 0], [7, 0], [8, 0]]]

In [51]:
state_t = converter.generate_state(0)

In [52]:
state_t

[(0, 5), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 5), (7, 5), (8, 5)]

In [53]:
state = list(converter.vars.keys())

In [54]:
state

[0, 1, 2, 3, 4, 5, 6, 7, 8]

In [55]:
mutex_variables = []
mutex_values = []
for mutex_group in converter.mutex_groups:
    var_group = []
    val_group = []
    for var, val in mutex_group:
        state_var = state_t[state.index(var)]
        var_group.append(state_var)
        val_group.append(val)
    mutex_variables.append(var_group)
    mutex_values.append(val_group)

In [56]:
mutex_variables

[[(1, 2), (0, 5), (6, 5), (7, 5), (8, 5)],
 [(2, 2), (0, 5), (6, 5), (7, 5), (8, 5)],
 [(3, 2), (0, 5), (6, 5), (7, 5), (8, 5)],
 [(4, 2), (0, 5), (6, 5), (7, 5), (8, 5)],
 [(5, 2), (0, 5), (6, 5), (7, 5), (8, 5)]]

In [57]:
mutex_values

[[0, 0, 1, 1, 1],
 [0, 1, 0, 2, 2],
 [0, 2, 2, 0, 3],
 [0, 3, 3, 3, 0],
 [0, 0, 0, 0, 0]]

In [58]:
for mutex_var, mutex_val in zip(mutex_variables, mutex_values):
    input = []
    for var in mutex_var:
        input.append(list(range(var[1])))
input

[[0, 1], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]

In [60]:
prods = list(product(*input))

In [93]:
len(prods)

1250

In [None]:
mutex_string_list = []
for prod in prods:
    is_valid = converter.valid(prod, mutex_val)
    mutex_string_list.append(is_valid)
mutex_string = ' '.join(mutex_string_list)
f = f*gtsam.DecisionTreeFactor(mutex_var, mutex_string)

In [155]:
help(CartesianProduct)

NameError: name 'CartesianProduct' is not defined