# Planning Notebook

In [1]:
import math
import unittest
import numpy as np
from itertools import product
import tqdm
from tqdm import tqdm_notebook
import copy
import pickle

# import gtsam
import gtsam
from gtsam import *
from gtsam.utils.test_case import GtsamTestCase

# import gtbook
import gtbook
from gtbook.display import *
from gtbook.discrete import *

# import local package
import gtsam_planner
from gtsam_planner import *

# import parser
import SASParser
from SASParser import SAS, Operator

# Notebook-wide registry of named discrete variables: SASToGTSAM.generate_state
# and generate_operator_key create keys in it, and pretty() hands it to gtbook.
variables = Variables()
def pretty(obj):
    """Pretty-print ``obj`` with gtbook, passing along the notebook-wide ``variables`` registry."""
    renderer = gtbook.display.pretty
    return renderer(obj, variables)

import graphviz
class show(graphviz.Source):
    """Render any object that exposes a ``dot()`` method as a Graphviz graph.

    ``graphviz.Source`` provides ``_repr_mimebundle_``, so an instance of this
    wrapper is drawn as a graph when it is the last expression of a cell.
    """

    def __init__(self, obj):
        """Build the Graphviz source from the DOT text returned by ``obj.dot()``."""
        dot_source = obj.dot()
        super().__init__(dot_source)

In [2]:
class SASToGTSAM():
    """Convert a parsed SAS planning task into GTSAM discrete factors.

    A "state" is a list holding one discrete key per SAS variable, ordered like
    ``self.state_keys``. The ``generate_*`` methods create those keys and the
    constraint factors (initial state, goal, mutex groups, operator
    transitions) that relate consecutive states.
    """

    def __init__(self, sas):
        """Cache the parts of ``sas`` needed by the factor builders.

        Args:
            sas: parsed task exposing ``initial_state``, ``goal``,
                ``variables``, ``operators`` and ``mutex_group``.
        """
        self.sas = sas
        self.init = sas.initial_state
        self.goal = sas.goal
        self.vars = self.sas.variables
        self.ops = self.sas.operators
        self.mutex_groups = self.sas.mutex_group
        self.ops_names = [op.name for op in self.ops]
        # Variable order shared by every per-timestep state list.
        self.state_keys = list(self.vars.keys())
        # Position of each variable inside a state list; replaces repeated
        # O(n) ``list.index`` scans while building factors.
        self._var_index = {var: i for i, var in enumerate(self.state_keys)}

    def _state_key(self, state, var):
        """Return the discrete key of SAS variable ``var`` within ``state``.

        Raises:
            ValueError: if ``var`` is not one of the task's variables.
        """
        try:
            position = self._var_index[var]
        except KeyError:
            raise ValueError(f"{var!r} is not a SAS variable") from None
        return state[position]

    def generate_state(self, timestep):
        """Create one discrete key per SAS variable for ``timestep``.

        Keys are named ``<var>_<timestep>`` and registered in the global
        ``variables`` registry with the value list stored in ``self.vars``
        (presumably the variable's domain).
        """
        state = []
        for var, val in self.vars.items():
            state_var = variables.discrete(str(var)+"_"+str(timestep), val)
            state.append(state_var)
        return state

    def generate_operator_key(self, timestep):
        """Create the operator-choice key ``op_<timestep>`` whose values are the operator names."""
        op_var = variables.discrete("op_"+str(timestep), self.ops_names)
        return op_var

    def generate_initial_factor(self, initial_state):
        """Constrain every key of ``initial_state`` to the SAS initial values.

        NOTE(review): values are taken from ``self.init`` in its own iteration
        order, which must match the order of ``initial_state`` (i.e. the order
        of ``self.vars``). Presumably the parser guarantees this; confirm.
        """
        keys = gtsam.DiscreteKeys()
        for key in initial_state:
            keys.push_back(key)
        init_values = list(self.init.values())
        return gtsam_planner.MultiValueConstraint(keys, init_values)

    def generate_goal_factor(self, goal_state):
        """Constrain only the goal variables of ``goal_state`` to their goal values."""
        keys = gtsam.DiscreteKeys()
        vals = []
        for goal_var, goal_val in self.goal.items():
            keys.push_back(self._state_key(goal_state, goal_var))
            vals.append(goal_val)
        return gtsam_planner.MultiValueConstraint(keys, vals)

    def generate_op_null(self, state_t, state_tp, operator):
        """Build the constraints for applying ``operator`` from ``state_t`` to ``state_tp``.

        Returns:
            op_f: MultiValueConstraint fixing precondition values (at t),
                effect values (at t+1) and prevail values (at both). Entries
                equal to -1 are left unconstrained.
            null_f: NullConstraint over every key of both states that the
                operator does not mention at all.
            keys: set of all keys (of both states) covered by op_f or null_f,
                including the -1 entries.
        """
        vals = []
        keys = set()
        multi_keys = gtsam.DiscreteKeys()
        for pre_var, pre_val in operator.precondition.items():
            key = self._state_key(state_t, pre_var)
            keys.add(key)
            if pre_val == -1:
                # Mentioned but unconstrained: excluded from the null constraint too.
                continue
            multi_keys.push_back(key)
            vals.append(pre_val)

        for eff_var, eff_val in operator.effect.items():
            key = self._state_key(state_tp, eff_var)
            keys.add(key)
            if eff_val == -1:
                continue
            multi_keys.push_back(key)
            vals.append(eff_val)

        if operator.num_prevail > 0:
            # Prevail conditions must hold at both t and t+1.
            for var, val in operator.prevail.items():
                key_t = self._state_key(state_t, var)
                key_tp = self._state_key(state_tp, var)
                keys.add(key_t)
                keys.add(key_tp)
                multi_keys.push_back(key_t)
                multi_keys.push_back(key_tp)
                vals.extend([val, val])
        # Sanity check: touched keys presumably come in (t, t+1) pairs.
        assert len(keys) % 2 == 0

        null_keys = gtsam.DiscreteKeys()
        for var in state_t + state_tp:
            if var not in keys:
                null_keys.push_back(var)
                keys.add(var)

        op_f = gtsam_planner.MultiValueConstraint(multi_keys, vals)
        null_f = gtsam_planner.NullConstraint(null_keys)
        return op_f, null_f, keys

    def generate_null_constraint(self, state_t, state_tp):
        """Return a NullOperatorConstraint over all keys of ``state_t`` and ``state_tp``.

        Intended to be true when ``state_t`` and ``state_tp`` are the same,
        false otherwise.
        """
        keys = gtsam.DiscreteKeys()
        for key in state_t + state_tp:
            keys.push_back(key)
        return gtsam_planner.NullOperatorConstraint(keys)

    def generate_mutex_factor(self, state_t):
        """Return one MutexConstraint per SAS mutex group, over the keys of ``state_t``.

        Each mutex group is iterated as ``(variable, value)`` pairs.
        """
        factors = []
        for mutex_group in self.mutex_groups:
            keys = gtsam.DiscreteKeys()
            mutex_values = []
            for var, val in mutex_group:
                keys.push_back(self._state_key(state_t, var))
                mutex_values.append(val)
            factors.append(gtsam_planner.MutexConstraint(keys, mutex_values))
        return factors

    def generate_op_factor(self, state_t, state_tp, op_key):
        """Combine every operator's constraints into one OperatorAddConstraint on ``op_key``.

        The factor's key list is the union of ``op_key`` and all keys returned
        by ``generate_op_null`` (set iteration order, as before).
        """
        op_consts = []
        null_consts = []
        keys_set = {op_key}
        for op in self.ops:
            op_const, null_const, keys = self.generate_op_null(state_t, state_tp, op)
            op_consts.append(op_const)
            null_consts.append(null_const)
            keys_set = keys_set.union(keys)

        keys = gtsam.DiscreteKeys()
        for key in keys_set:
            keys.push_back(key)

        return gtsam_planner.OperatorAddConstraint(op_key, keys, op_consts, null_consts)

In [3]:
# Parse the gripper SAS task and wrap it for factor-graph construction.
sas_dir = "sas/gripper_example.sas"
sas = SAS()
sas.read_file(sas_dir)
converter = SASToGTSAM(sas)

In [4]:
def plan(plan_length):
    """Grow the planning horizon until the factor graph admits a plan.

    For every horizon ``k`` in ``range(2, plan_length)`` a fresh factor graph
    is built over ``k`` states (``k - 1`` operator steps) using the global
    ``converter``, then optimized. The first horizon whose optimum has a
    non-zero value is returned.

    Factors are pushed in the order: mutex groups, goal, initial state,
    operators. The last ``k - 1`` factors of the returned graph are therefore
    the operator factors, one per step.

    Args:
        plan_length: exclusive upper bound on the number of states tried.

    Returns:
        ``(graph, val, k)``: the factor graph, its optimal assignment, and
        the number of states in the horizon.

    Raises:
        RuntimeError: if no horizon below ``plan_length`` admits a plan.
    """
    for k in range(2, plan_length):
        print(k)
        states = []
        mutex_factors = []
        for i in range(k):
            state_t = converter.generate_state(i)
            states.append(state_t)
            mutex_factors.append(converter.generate_mutex_factor(state_t))

        op_factors = []
        for j in range(k - 1):
            op_key = converter.generate_operator_key(j)
            op_factors.append(converter.generate_op_factor(states[j], states[j + 1], op_key))
        initial_factor = converter.generate_initial_factor(states[0])
        goal_factor = converter.generate_goal_factor(states[-1])

        graph = gtsam.DiscreteFactorGraph()
        for group in mutex_factors:
            for factor in group:
                graph.push_back(factor)
        graph.push_back(goal_factor)
        graph.push_back(initial_factor)
        # Operator factors go last; downstream code reads them from the tail.
        for op_factor in op_factors:
            graph.push_back(op_factor)

        val = graph.optimize()
        # A zero value means no assignment satisfies every constraint at this horizon.
        if graph(val) != 0.0:
            return graph, val, k
    # Previously a string sentinel was returned here, which made the caller's
    # tuple unpacking fail with an unrelated error.
    raise RuntimeError(
        f"no plan found with fewer than {plan_length} states; try a longer plan length"
    )

In [5]:
# Try horizons of 2..19 states; stops at the first satisfiable one.
graph, val, k = plan(plan_length=20)

2
3
4
5
6
7
8
9
10
11
12


In [6]:
# Plain-text dump of the factor graph: one line per factor with its keys.
print(str(graph))


size: 61
factor 0: MutexConstraint on 513 513 511 512 
factor 1: MutexConstraint on 514 514 511 512 
factor 2: MutexConstraint on 515 515 511 512 
factor 3: MutexConstraint on 516 516 511 512 
factor 4: MutexConstraint on 520 520 518 519 
factor 5: MutexConstraint on 521 521 518 519 
factor 6: MutexConstraint on 522 522 518 519 
factor 7: MutexConstraint on 523 523 518 519 
factor 8: MutexConstraint on 527 527 525 526 
factor 9: MutexConstraint on 528 528 525 526 
factor 10: MutexConstraint on 529 529 525 526 
factor 11: MutexConstraint on 530 530 525 526 
factor 12: MutexConstraint on 534 534 532 533 
factor 13: MutexConstraint on 535 535 532 533 
factor 14: MutexConstraint on 536 536 532 533 
factor 15: MutexConstraint on 537 537 532 533 
factor 16: MutexConstraint on 541 541 539 540 
factor 17: MutexConstraint on 542 542 539 540 
factor 18: MutexConstraint on 543 543 539 540 
factor 19: MutexConstraint on 544 544 539 540 
factor 20: MutexConstraint on 548 548 546 547 
factor 21: Mu

In [7]:
# plan() pushes the k - 1 operator factors last, one per step, in step order.
# Read them from the tail of the graph and print the chosen operator names.
first_op_index = graph.size() - (k - 1)
op_consts = [graph.at(i) for i in range(first_op_index, graph.size())]
for op_const in op_consts:
    print(converter.ops_names[val[op_const.operatorKey()]])

pick ball1 rooma left
pick ball2 rooma right
move rooma roomb
drop ball1 roomb left
drop ball2 roomb right
move roomb rooma
pick ball3 rooma left
pick ball4 rooma right
move rooma roomb
drop ball3 roomb left
drop ball4 roomb right
