In [5]:
import sys
sys.path.append('../scripts/')
from puddle_world import *
import itertools
import collections
%matplotlib widget

In [6]:
class PolicyEvaluator:
    """Grid-discretized model used to evaluate a fixed policy.

    The pose space (x, y, theta) is split into cells of size ``widths``.
    x/y run from ``lowerleft`` to ``upperright`` and theta runs over
    [0, 2*pi). On construction the following attributes are built:

    * ``value_function`` / ``final_state_flags``: initial cell values and
      goal-cell flags (see ``init_value_function``);
    * ``policy``: the action returned by ``PuddleIgnoreAgent.policy`` at each
      cell centre;
    * ``actions``: the distinct actions that appear in ``policy``;
    * ``state_transition_probs``: for each (action, theta index), the sampled
      distribution of relative cell displacements after one step.
    """

    # lowerleft / upperright were added as arguments so the map bounds can be
    # configured by the caller.
    def __init__(self, widths, goal, time_interval, sampling_num, lowerleft=np.array([-4, -4]).T, upperright=np.array([4, 4]).T):
        """Discretize the pose space and precompute policy / transition tables.

        Args:
            widths: cell size per axis (x, y, theta).
            goal: goal object; ``final_state`` calls ``goal.inside(pose)`` and
                ``init_value_function`` reads ``goal.value``.
            time_interval: step duration passed to ``IdealRobot.state_transition``.
            sampling_num: sample points per axis inside one cell when estimating
                transition probabilities (``sampling_num**3`` samples in total).
            lowerleft: (x, y) lower bound of the map.
            upperright: (x, y) upper bound of the map.
        """
        self.pose_min = np.r_[lowerleft, 0]
        self.pose_max = np.r_[upperright, math.pi*2]
        self.widths = widths
        self.goal = goal

        self.index_nums = self._count_cells(self.pose_max - self.pose_min, self.widths)
        nx, ny, nt = self.index_nums
        self.indexes = list(itertools.product(range(nx), range(ny), range(nt)))

        self.value_function, self.final_state_flags = self.init_value_function()
        self.policy = self.init_policy()

        # Distinct (nu, omega) actions used anywhere in the policy table.
        self.actions = list(set(tuple(self.policy[i]) for i in self.indexes))

        self.state_transition_probs = self.init_state_transition_probs(
            time_interval, sampling_num)

    @staticmethod
    def _count_cells(span, widths, eps=1e-9):
        """Return the number of whole cells of size ``widths`` that fit in ``span``.

        A plain ``astype(int)`` truncation loses a cell whenever the float
        division lands just below an integer (e.g. ``0.7 / 0.1 == 6.999...``).
        ``eps`` absorbs that rounding error.
        """
        ratio = np.asarray(span, dtype=float) / widths
        return (ratio + eps).astype(int)

    def init_state_transition_probs(self, time_interval, sampling_num):
        """Estimate transition distributions by sampling poses inside one cell.

        For every action ``a`` and theta index ``i_t``, ``sampling_num`` points
        per axis are spread over the cell with x/y index 0 and theta index
        ``i_t``. Each point is moved one step with
        ``IdealRobot.state_transition`` and its resulting cell index is
        recorded relative to the starting cell.

        Returns:
            dict mapping ``(a, i_t)`` to a list of
            ``(relative_index_array, probability)`` pairs.
        """
        # Offsets inside a cell, kept strictly inside (0, width) so that no
        # sample lies exactly on a cell edge.
        dx = np.linspace(0.001, self.widths[0]*0.999, sampling_num)
        dy = np.linspace(0.001, self.widths[1]*0.999, sampling_num)
        dt = np.linspace(0.001, self.widths[2]*0.999, sampling_num)
        samples = list(itertools.product(dx, dy, dt))

        transition_probs = {}
        for a in self.actions:
            for i_t in range(self.index_nums[2]):
                # The starting cell depends only on i_t, not on the sample.
                before_index = np.array([0, 0, i_t]).T
                theta_offset = i_t*self.widths[2]

                transitions = []
                for sx, sy, st in samples:
                    before = np.array([sx, sy, st + theta_offset]).T + self.pose_min
                    after = IdealRobot.state_transition(
                        a[0], a[1], time_interval, before)
                    after_index = np.floor(
                        (after - self.pose_min)/self.widths).astype(int)
                    transitions.append(after_index - before_index)

                unique, count = np.unique(
                    transitions, axis=0, return_counts=True)
                probs = [c/len(samples) for c in count]
                transition_probs[a, i_t] = list(zip(unique, probs))

        return transition_probs

    def init_policy(self):
        """Tabulate ``PuddleIgnoreAgent.policy`` at every cell centre.

        Returns:
            array of shape ``(*index_nums, 2)`` holding the action chosen
            for each cell.
        """
        policy = np.zeros(np.r_[self.index_nums, 2])
        for index in self.indexes:
            center = self.pose_min + self.widths * (np.array(index).T + 0.5)
            policy[index] = PuddleIgnoreAgent.policy(center, self.goal)

        return policy

    def init_value_function(self):
        """Build the initial value table and the goal-cell flags.

        Returns:
            ``(v, f)``. ``f[index]`` is 1.0 when ``final_state(index)`` holds,
            else 0.0. ``v[index]`` is ``goal.value`` for those cells and -100.0
            for every other cell.
        """
        v = np.empty(self.index_nums)
        f = np.zeros(self.index_nums)

        for index in self.indexes:
            f[index] = self.final_state(np.array(index).T)
            v[index] = self.goal.value if f[index] else -100.0

        return v, f

    def final_state(self, index):
        """Return True when all four x/y corners of cell ``index`` are inside the goal.

        The theta component of each corner pose is the cell's upper theta
        bound. Presumably ``goal.inside`` only looks at x/y; TODO confirm.
        """
        x_min, y_min, _ = self.pose_min + self.widths*index
        x_max, y_max, t_max = self.pose_min + self.widths*(index + 1)

        corners = [[x_min, y_min, t_max], [x_min, y_max, t_max],
                   [x_max, y_min, t_max], [x_max, y_max, t_max]]
        return all(self.goal.inside(np.array(c).T) for c in corners)

In [7]:
# Cells of 0.2 x 0.2 x 10 degrees, goal at (-3, -3), 0.1 s step,
# 10 samples per axis (1000 per cell) for the transition estimate.
pe = PolicyEvaluator(
    widths=np.array([0.2, 0.2, math.pi / 18]),
    goal=Goal(-3, -3),
    time_interval=0.1,
    sampling_num=10,
)
pe.state_transition_probs

{((1.0, 0.0), 0): [(array([0, 0, 0]), 0.455),
  (array([0, 1, 0]), 0.045),
  (array([1, 0, 0]), 0.455),
  (array([1, 1, 0]), 0.045)],
 ((1.0, 0.0), 1): [(array([0, 0, 0]), 0.415),
  (array([0, 1, 0]), 0.085),
  (array([1, 0, 0]), 0.415),
  (array([1, 1, 0]), 0.085)],
 ((1.0, 0.0), 2): [(array([0, 0, 0]), 0.401),
  (array([0, 1, 0]), 0.129),
  (array([1, 0, 0]), 0.359),
  (array([1, 1, 0]), 0.111)],
 ((1.0, 0.0), 3): [(array([0, 0, 0]), 0.42),
  (array([0, 1, 0]), 0.18),
  (array([1, 0, 0]), 0.28),
  (array([1, 1, 0]), 0.12)],
 ((1.0, 0.0), 4): [(array([0, 0, 0]), 0.384),
  (array([0, 1, 0]), 0.236),
  (array([1, 0, 0]), 0.236),
  (array([1, 1, 0]), 0.144)],
 ((1.0, 0.0), 5): [(array([0, 0, 0]), 0.42),
  (array([0, 1, 0]), 0.28),
  (array([1, 0, 0]), 0.18),
  (array([1, 1, 0]), 0.12)],
 ((1.0, 0.0), 6): [(array([0, 0, 0]), 0.401),
  (array([0, 1, 0]), 0.359),
  (array([1, 0, 0]), 0.129),
  (array([1, 1, 0]), 0.111)],
 ((1.0, 0.0), 7): [(array([0, 0, 0]), 0.415),
  (array([0, 1, 0]), 0.4