In [1]:
import sys      ###policyevaluator4header
sys.path.append('../scripts/')
from puddle_world import *
import itertools 
import collections #追加

In [2]:
class PolicyEvaluator:   ###policyevaluator4
    """Grid discretisation of (x, y, theta) poses with a value function, a policy
    and a table of cell-to-cell transition probabilities.

    Attributes set by the constructor:
        pose_min, pose_max: lower/upper pose bounds (theta spans 0 .. 2*pi).
        widths: cell size along each pose axis.
        goal: object providing ``inside(pose)`` and ``value``.
        index_nums: number of cells along each axis.
        indexes: list of every (ix, iy, it) cell index.
        value_function, final_state_flags: per-cell arrays.
        policy: per-cell 2-element control output.
        actions: distinct control outputs appearing in ``policy``.
        state_transition_probs: dict keyed by (action, theta index).
    """

    def __init__(self, widths, goal, time_interval, sampling_num, lowerleft=np.array([-4, -4]).T, upperright=np.array([4, 4]).T):
        """Build the grid, then the value function, the policy and the transition table.

        Args:
            widths: cell widths for x, y and theta.
            goal: goal object (needs ``inside`` and ``value``).
            time_interval: duration passed to ``IdealRobot.state_transition``.
            sampling_num: samples per axis inside a cell (``sampling_num**3`` in total).
            lowerleft: (x, y) of the lower-left corner of the region.
            upperright: (x, y) of the upper-right corner of the region.
        """
        self.pose_min = np.r_[lowerleft, 0]
        self.pose_max = np.r_[upperright, math.pi*2]
        self.widths = widths
        self.goal = goal

        # Number of cells per axis; the quotient is truncated toward zero.
        extent = self.pose_max - self.pose_min
        self.index_nums = (extent/self.widths).astype(int)
        n_x, n_y, n_theta = self.index_nums
        self.indexes = list(itertools.product(range(n_x), range(n_y), range(n_theta)))

        self.value_function, self.final_state_flags = self.init_value_function()
        self.policy = self.init_policy()
        # Collapse the per-cell policy into the distinct actions it contains.
        self.actions = list({tuple(self.policy[cell]) for cell in self.indexes})

        self.state_transition_probs = self.init_state_transition_probs(time_interval, sampling_num)

    def init_state_transition_probs(self, time_interval, sampling_num):
        """Estimate index-displacement probabilities for every (action, theta index).

        ``sampling_num**3`` poses are laid out on a regular lattice inside a cell,
        each is moved with ``IdealRobot.state_transition`` and the resulting change
        of cell index is tallied.

        Returns:
            dict mapping ``(action, i_t)`` to a list of
            ``(index_delta_array, probability)`` pairs.
        """
        # Lattice of offsets inside one cell; both ends stay slightly inside the
        # cell so that a sample never starts in the neighbouring cell.
        offsets = [np.linspace(0.001, w*0.999, sampling_num) for w in self.widths[:3]]
        samples = list(itertools.product(*offsets))
        total = sampling_num**3

        probs_by_key = {}
        for action, i_t in itertools.product(self.actions, range(self.index_nums[2])):
            start_index = np.array([0, 0, i_t])     # cell index before the move
            theta_base = i_t*self.widths[2]         # heading offset of this theta bin

            deltas = []
            for off_x, off_y, off_t in samples:
                start_pose = np.array([off_x, off_y, off_t + theta_base]) + self.pose_min
                # presumably the arguments are (nu, omega, time, pose) - confirm against IdealRobot
                end_pose = IdealRobot.state_transition(action[0], action[1], time_interval, start_pose)
                end_index = np.floor((end_pose - self.pose_min)/self.widths).astype(int)
                deltas.append(end_index - start_index)

            # Tally how often each index displacement occurred and normalise by the sample count.
            cells, counts = np.unique(deltas, axis=0, return_counts=True)
            probs_by_key[action, i_t] = list(zip(cells, counts/total))

        return probs_by_key

    def init_policy(self):
        """Return an array of shape ``index_nums + (2,)`` holding, for every cell,
        the output of ``PuddleIgnoreAgent.policy`` evaluated at the cell centre."""
        table = np.zeros((*self.index_nums, 2))     # last axis: the 2-element control output
        for cell in self.indexes:
            centre = self.pose_min + self.widths*(np.array(cell) + 0.5)
            table[cell] = PuddleIgnoreAgent.policy(centre, self.goal)

        return table

    def init_value_function(self):
        """Return ``(values, flags)`` arrays covering every cell.

        ``flags`` is 1 for final-state cells and 0 otherwise; ``values`` is
        ``goal.value`` for final-state cells and -100.0 everywhere else.
        """
        values = np.full(self.index_nums, -100.0)
        flags = np.zeros(self.index_nums)

        for cell in self.indexes:
            is_final = self.final_state(np.array(cell))
            flags[cell] = is_final
            if is_final:
                values[cell] = self.goal.value

        return values, flags

    def final_state(self, index):
        """Return True when all four xy corners of cell ``index`` satisfy ``goal.inside``."""
        x_lo, y_lo, _unused = self.pose_min + self.widths*index
        # Upper corner = lower corner of the diagonally adjacent cell; its heading
        # value is the one handed to goal.inside for every corner.
        x_hi, y_hi, theta_hi = self.pose_min + self.widths*(index + 1)

        corner_xy = itertools.product((x_lo, x_hi), (y_lo, y_hi))
        return all([self.goal.inside(np.array([x, y, theta_hi])) for x, y in corner_xy])

In [3]:
# Build the evaluator with cell widths 0.2 x 0.2 x pi/18 and display the transition table.
pe = PolicyEvaluator(   ##policyevaluator4exec
    widths=np.array([0.2, 0.2, math.pi/18]).T,
    goal=Goal(-3, -3),
    time_interval=0.1,
    sampling_num=10,
)
pe.state_transition_probs

{((0.0, -2.0), 0): [(array([ 0,  0, -2]), 0.20000000000000001),
  (array([ 0,  0, -1]), 0.80000000000000004)],
 ((0.0, -2.0), 1): [(array([ 0,  0, -2]), 0.20000000000000001),
  (array([ 0,  0, -1]), 0.80000000000000004)],
 ((0.0, -2.0), 2): [(array([ 0,  0, -2]), 0.20000000000000001),
  (array([ 0,  0, -1]), 0.80000000000000004)],
 ((0.0, -2.0), 3): [(array([ 0,  0, -2]), 0.20000000000000001),
  (array([ 0,  0, -1]), 0.80000000000000004)],
 ((0.0, -2.0), 4): [(array([ 0,  0, -2]), 0.20000000000000001),
  (array([ 0,  0, -1]), 0.80000000000000004)],
 ((0.0, -2.0), 5): [(array([ 0,  0, -2]), 0.20000000000000001),
  (array([ 0,  0, -1]), 0.80000000000000004)],
 ((0.0, -2.0), 6): [(array([ 0,  0, -2]), 0.20000000000000001),
  (array([ 0,  0, -1]), 0.80000000000000004)],
 ((0.0, -2.0), 7): [(array([ 0,  0, -2]), 0.20000000000000001),
  (array([ 0,  0, -1]), 0.80000000000000004)],
 ((0.0, -2.0), 8): [(array([ 0,  0, -2]), 0.20000000000000001),
  (array([ 0,  0, -1]), 0.80000000000000004)],
 