# Planning Notebook

In [1]:
import math
import unittest
import numpy as np
from itertools import product
import tqdm
from tqdm import tqdm_notebook
import copy
import pickle

# import gtsam
import gtsam
from gtsam import *
from gtsam.utils.test_case import GtsamTestCase

# import gtbook
import gtbook
from gtbook.display import *
from gtbook.discrete import *

# import local package
import gtsam_planner
from gtsam_planner import *

# import parser
import SASParser
from SASParser import SAS, Operator

# Notebook-wide registry of discrete variables; every key created by the
# converter below is registered here so it can be rendered by name.
variables = Variables()
def pretty(obj):
    """Render `obj` with gtbook's pretty printer, using the shared `variables` registry."""
    rendered = gtbook.display.pretty(obj, variables)
    return rendered

import graphviz
class show(graphviz.Source):
    """Notebook helper that renders any object exposing a ``dot()`` method."""

    def __init__(self, obj):
        """Build the graphviz source from the string returned by ``obj.dot()``.

        The instance is a ``graphviz.Source``; per the original author's note,
        its ``_repr_mimebundle_`` method is what makes the notebook render it.
        """
        dot_source = obj.dot()
        super().__init__(dot_source)

In [2]:
class SASToGTSAM():
    """Build GTSAM discrete constraints from a parsed SAS planning task.

    The converter caches the task's initial state, goal, variables, operators
    and mutex groups, and exposes one ``generate_*`` method per kind of key or
    factor. A "state" throughout this class is the list returned by
    ``generate_state``: one discrete key per SAS variable, in the order of
    ``self.state_keys``.
    """

    def __init__(self, sas):
        """Cache the parts of ``sas`` that the factor builders read.

        Args:
            sas: parsed task exposing ``initial_state``, ``goal``,
                ``variables`` (a mapping), ``operators`` (each with a
                ``name``) and ``mutex_group``.
        """
        self.sas = sas
        self.init = sas.initial_state
        self.goal = sas.goal
        self.vars = self.sas.variables
        self.ops = self.sas.operators
        self.mutex_groups = self.sas.mutex_group
        self.ops_names = [op.name for op in self.ops]
        self.state_keys = list(self.vars.keys())
        # Name -> position in a state list, so lookups need no linear scan.
        self._state_index = {name: pos for pos, name in enumerate(self.state_keys)}

    def _state_var(self, state, name):
        """Return the entry of ``state`` that belongs to SAS variable ``name``.

        Raises:
            ValueError: if ``name`` is not one of the task's variables.
        """
        try:
            position = self._state_index[name]
        except KeyError:
            raise ValueError("unknown state variable {!r}".format(name)) from None
        return state[position]

    def generate_state(self, timestep):
        """Create one discrete key per SAS variable for the given timestep.

        Each key is registered in the notebook-level ``variables`` registry
        under the name ``"<variable>_<timestep>"``.

        Returns:
            list of keys, ordered like ``self.state_keys``.
        """
        state = []
        for var, val in self.vars.items():
            state_var = variables.discrete(str(var)+"_"+str(timestep), val)
            state.append(state_var)
        return state

    def generate_operator_key(self, timestep):
        """Create the discrete key ``"op_<timestep>"`` whose domain is the operator names."""
        op_var = variables.discrete("op_"+str(timestep), self.ops_names)
        return op_var

    def generate_initial_factor(self, initial_state):
        """Constrain ``initial_state`` to the task's initial values.

        The values are taken from ``self.init`` in its own iteration order and
        paired positionally with the keys of ``initial_state``.
        NOTE(review): this assumes ``self.init`` iterates in the same order as
        ``self.vars`` -- confirm against the parser.
        """
        keys = gtsam.DiscreteKeys()
        for key in initial_state:
            keys.push_back(key)
        init_values = list(self.init.values())
        init_f = gtsam_planner.MultiValueConstraint(keys, init_values)
        return init_f

    def generate_goal_factor(self, goal_state):
        """Constrain the goal variables of ``goal_state`` to their goal values.

        Only variables mentioned in ``self.goal`` are included in the factor.
        """
        keys = gtsam.DiscreteKeys()
        vals = []
        for goal_var, goal_val in self.goal.items():
            keys.push_back(self._state_var(goal_state, goal_var))
            vals.append(goal_val)
        goal_f = gtsam_planner.MultiValueConstraint(keys, vals)
        return goal_f

    def generate_op_null(self, state_t, state_tp, operator):
        """Build the constraints describing one operator between two states.

        Args:
            state_t: state before the operator is applied.
            state_tp: state after the operator is applied.
            operator: object with ``precondition`` and ``effect`` mappings
                from variable name to value.

        Returns:
            A tuple ``(op_f, null_f, keys)``:
            ``op_f`` is a MultiValueConstraint over the precondition keys (in
            ``state_t``) and effect keys (in ``state_tp``) whose value is not
            -1; ``null_f`` is a NullConstraint over every key of both states
            that the operator does not mention; ``keys`` is the set of all
            keys of both states.
        """
        vals = []
        keys = set()
        multi_keys = gtsam.DiscreteKeys()
        for pre_var, pre_val in operator.precondition.items():
            key = self._state_var(state_t, pre_var)
            keys.add(key)
            # A value of -1 is recorded as "touched" but not constrained
            # (presumably the SAS "no requirement" marker -- verify).
            if pre_val == -1:
                continue
            multi_keys.push_back(key)
            vals.append(pre_val)

        for eff_var, eff_val in operator.effect.items():
            key = self._state_var(state_tp, eff_var)
            keys.add(key)
            if eff_val == -1:
                continue
            multi_keys.push_back(key)
            vals.append(eff_val)

        # Sanity check kept from the original: touched keys come in pairs.
        assert len(keys) % 2 == 0
        null_keys = gtsam.DiscreteKeys()
        for var in state_t+state_tp:
            if var not in keys:
                null_keys.push_back(var)
                keys.add(var)

        op_f = gtsam_planner.MultiValueConstraint(multi_keys, vals)
        null_f = gtsam_planner.NullConstraint(null_keys)
        return op_f, null_f, keys

    def generate_null_constraint(self, state_t, state_tp):
        """
        Build a NullOperatorConstraint over all keys of both states.

        Per the original author's note: true if state_t and state_tp are the
        same, false otherwise.
        """
        keys = gtsam.DiscreteKeys()
        for key in (state_t+state_tp):
            keys.push_back(key)
        null_constraint = gtsam_planner.NullOperatorConstraint(keys)
        return null_constraint

    def generate_mutex_factor(self, state_t):
        """Build one MutexConstraint per mutex group for the given state.

        Each mutex group is an iterable of ``(variable, value)`` pairs.

        Returns:
            list of MutexConstraint, one per group, in group order (empty if
            the task has no mutex groups).
        """
        factors = []
        for mutex_group in self.mutex_groups:
            keys = gtsam.DiscreteKeys()
            mutex_val = []
            for var, val in mutex_group:
                keys.push_back(self._state_var(state_t, var))
                mutex_val.append(val)
            factors.append(gtsam_planner.MutexConstraint(keys, mutex_val))
        return factors

    def generate_op_factor(self, state_t, state_tp, op_key):
        """Combine the constraints of every operator into one factor.

        For each operator in ``self.ops`` (same order as ``self.ops_names``)
        the operator and null constraints from ``generate_op_null`` are
        collected and handed, together with ``op_key`` and the union of all
        involved keys, to ``OperatorAddConstraint``.
        """
        op_consts = []
        null_consts = []
        keys_set = set()
        keys_set.add(op_key)
        for op in self.ops:
            op_const, null_const, keys = self.generate_op_null(state_t, state_tp, op)
            op_consts.append(op_const)
            null_consts.append(null_const)
            keys_set = keys_set.union(keys)

        keys = gtsam.DiscreteKeys()
        for key in keys_set:
            keys.push_back(key)

        op_factor = gtsam_planner.OperatorAddConstraint(op_key, keys, op_consts, null_consts)
        return op_factor

In [3]:
# Parse the block-world example task and wrap it in a converter.
# `converter` is read as a global by `plan` and by later cells.
sas = SAS()
# Path is relative to the notebook's working directory.
sas_dir = "sas/block_example.sas"
sas.read_file(sas_dir)
converter = SASToGTSAM(sas)

In [4]:
def plan(plan_length):
    """Search increasing horizons for a satisfiable planning factor graph.

    For each number of states ``k`` in ``range(2, plan_length)`` (so
    ``plan_length`` itself is excluded) a discrete factor graph is built from
    the global ``converter``: mutex factors for every state, then the goal
    factor on the last state, then the initial factor on the first state, then
    one operator factor per consecutive pair of states. The graph is optimized
    and the first horizon whose optimum evaluates to a non-zero value wins.

    Returns:
        ``(graph, val, k)`` for the first satisfiable horizon, or the string
        ``"longer plan length?"`` if none is found. Note that a caller which
        unpacks three values will fail on the string result.
    """
    for k in range(2, plan_length):
        print(k)

        # One state per timestep, each with its own list of mutex factors.
        states = []
        mutex_factors = []
        for timestep in range(k):
            state = converter.generate_state(timestep)
            states.append(state)
            mutex_factors.append(converter.generate_mutex_factor(state))

        # One operator factor linking each state to its successor.
        op_factors = []
        for timestep, (current, following) in enumerate(zip(states, states[1:])):
            op_key = converter.generate_operator_key(timestep)
            op_factors.append(converter.generate_op_factor(current, following, op_key))

        initial_factor = converter.generate_initial_factor(states[0])
        goal_factor = converter.generate_goal_factor(states[-1])

        # Insertion order matters to readers of the graph: the operator
        # factors are always the last k-1 entries.
        graph = gtsam.DiscreteFactorGraph()
        for group in mutex_factors:
            for factor in group:
                graph.push_back(factor)
        graph.push_back(goal_factor)
        graph.push_back(initial_factor)
        for factor in op_factors:
            graph.push_back(factor)

        val = graph.optimize()
        if graph(val) != 0.0:
            return graph, val, k
    return "longer plan length?"

In [5]:
# Try horizons 2..9; unpack the satisfied graph, its assignment and the
# number of states. If no horizon works, `plan` returns a string and this
# unpacking raises.
graph, val, k = plan(10)

2
3
4
5
6
7


In [8]:
# Dump the factor graph found by `plan` (one line per factor with its keys).
print(graph)


size: 43
factor 0: MutexConstraint on 196 195 201 202 203 
factor 1: MutexConstraint on 197 195 201 202 203 
factor 2: MutexConstraint on 198 195 201 202 203 
factor 3: MutexConstraint on 199 195 201 202 203 
factor 4: MutexConstraint on 200 195 201 202 203 
factor 5: MutexConstraint on 205 204 210 211 212 
factor 6: MutexConstraint on 206 204 210 211 212 
factor 7: MutexConstraint on 207 204 210 211 212 
factor 8: MutexConstraint on 208 204 210 211 212 
factor 9: MutexConstraint on 209 204 210 211 212 
factor 10: MutexConstraint on 214 213 219 220 221 
factor 11: MutexConstraint on 215 213 219 220 221 
factor 12: MutexConstraint on 216 213 219 220 221 
factor 13: MutexConstraint on 217 213 219 220 221 
factor 14: MutexConstraint on 218 213 219 220 221 
factor 15: MutexConstraint on 223 222 228 229 230 
factor 16: MutexConstraint on 224 222 228 229 230 
factor 17: MutexConstraint on 225 222 228 229 230 
factor 18: MutexConstraint on 226 222 228 229 230 
factor 19: MutexConstraint on 2

In [6]:
# `plan` pushes the operator factors last, so they are the final k-1 entries
# of the graph; collect them from the end backwards.
op_consts = [graph.at(graph.size() - offset) for offset in range(1, k)]
val_list = []
# Walk them in timestep order and print the operator chosen at each step.
for op_const in op_consts[::-1]:
    op_index = val[op_const.operatorKey()]
    print(converter.ops_names[op_index])

pick-up b
stack b a
pick-up c
stack c b
pick-up d
stack d c


In [24]:
class TestSASToGTSAM(GtsamTestCase):
    """Tests for the SASToGTSAM converter on the block-world example task."""

    def setUp(self):
        """Parse the example task and enumerate every joint state assignment."""
        sas = SAS()
        sas_dir = "sas/block_example.sas"
        sas.read_file(sas_dir)
        self.converter = SASToGTSAM(sas)
        self.init_state = self.converter.generate_state(0)
        self.next_state = self.converter.generate_state(1)
        self.op_key = self.converter.generate_operator_key(0)

        # One index range per variable domain; their Cartesian product is every
        # possible state. NOTE(review): the original notebook carried a remark
        # here that an attempt "crashes the kernel"; the product is fully
        # materialized, so its size is the product of all domain sizes.
        domains = [list(range(len(domain))) for domain in self.converter.vars.values()]
        self.prods = list(product(*domains))

    def createVal(self, states, prod):
        """Build DiscreteValues assigning ``prod[i]`` to the i-th key of ``states``."""
        values = gtsam.DiscreteValues()
        for state, val in zip(states, prod):
            # state is a (key, cardinality) pair; index by the key.
            values[state[0]] = val
        return values

    def createOperatorVal(self, state1, state2, op):
        """Build DiscreteValues satisfying ``op``.

        Preconditions are assigned on ``state1`` and effects on ``state2``;
        entries whose value is -1 are skipped.
        """
        state = self.converter.state_keys
        values = gtsam.DiscreteValues()
        for var, val in op.precondition.items():
            if val == -1:
                continue
            values[state1[state.index(var)][0]] = val
        for var, val in op.effect.items():
            if val == -1:
                continue
            values[state2[state.index(var)][0]] = val
        return values

    def valid(self, prod, mutex_group):
        """Return 0.0 if more than one (var, val) pair of the group holds in ``prod``, else 1.0."""
        count = 0
        for var, val in mutex_group:
            if prod[var] == val:
                count += 1
            if count > 1:
                return 0.0
        return 1.0

    def test_generateState(self):
        """The example task has nine state variables."""
        assert len(self.init_state) == 9

    def test_generateOperatorState(self):
        """The operator key's cardinality equals the number of operators."""
        # there are 32 possible operators
        assert self.op_key[1] == 32

    def test_generateInitial(self):
        """Exactly one assignment satisfies the initial factor, and it is the initial state."""
        state = self.converter.state_keys
        initial_factor = self.converter.generate_initial_factor(self.init_state)
        total = 0.0
        for prod in self.prods:
            values = self.createVal(self.init_state, prod)
            output = initial_factor(values)
            if output == 1:
                for var, val in self.converter.init.items():
                    assert values[self.init_state[state.index(var)][0]] == val
            total += output
        assert total == 1

    def test_generateGoal(self):
        """Every assignment accepted by the goal factor agrees with the goal values."""
        state = self.converter.state_keys
        goal_factor = self.converter.generate_goal_factor(self.next_state)
        total = 0.0
        for prod in self.prods:
            values = self.createVal(self.next_state, prod)
            output = goal_factor(values)
            if output == 1:
                for var, val in self.converter.goal.items():
                    assert values[self.next_state[state.index(var)][0]] == val
            total += output
        # The goal only fixes some variables, so several states satisfy it.
        assert total > 1

    def test_generateOperatorConstraint(self):
        """Each operator's constraint accepts the assignment built from that operator."""
        for operator in self.converter.ops:
            # generate_op_null returns (operator constraint, null constraint, keys);
            # only the operator constraint is checked here.
            op_factor, _, _ = self.converter.generate_op_null(self.init_state, self.next_state, operator)
            values = self.createOperatorVal(self.init_state, self.next_state, operator)
            assert op_factor(values) == 1.0

    # Disabled: exhaustive check of the mutex factors against `valid`.
    # def test_generateMutex(self):
    #     mutex_factors = self.converter.generate_mutex_factor(self.init_state)
    #     for prod in tqdm_notebook(self.prods, desc='possible states'):
    #         for mutex_factor, mutex_group in zip(mutex_factors, self.converter.mutex_groups):
    #             check_valid = self.valid(prod, mutex_group)
    #             values = self.createVal(self.init_state, prod)
    #             factor_valid = mutex_factor(values)
    #             tree_factor = mutex_factor.toDecisionTreeFactor()
    #             tree_valid = tree_factor(values)
    #             assert tree_valid == factor_valid == check_valid

    # TODO: add test_generateOperatorFactor covering generate_op_factor.
        

In [25]:
# unittest.main(argv=[''], verbosity=2, exit=False)