In [1]:
import numpy as np
import pandas as pd

In [2]:
# set observation model
observation_model = pd.DataFrame()
observation_model[1] = [0.9, 0.1, 0]
observation_model[2] = [0.1, 0.9, 0]
observation_model["end"] = [0, 0, 1]
observation_model.index = ["non-terminal in third col", "other non-terminal", "terminal"]
observation_model

Unnamed: 0,1,2,end
non-terminal in third col,0.9,0.1,0
other non-terminal,0.1,0.9,0
terminal,0.0,0.0,1


In [20]:
# initialize belief state
BELIEF_STATE = np.full((3,4), (1/9))
# set terminal states
BELIEF_STATE[0][3] = 0
BELIEF_STATE[1][3] = 0
# set non-accessible state
BELIEF_STATE[1][1] = np.nan
BELIEF_STATE

# BELIEF_STATE = np.zeros((3,4))
# BELIEF_STATE[2][0] = 1
# BELIEF_STATE[1][1] = np.nan
# BELIEF_STATE

array([[0.11111111, 0.11111111, 0.11111111, 0.        ],
       [0.11111111,        nan, 0.11111111, 0.        ],
       [0.11111111, 0.11111111, 0.11111111, 0.11111111]])

In [4]:
def is_valid(state):
    row,col = state
    num_rows, num_cols = BELIEF_STATE.shape
    empty_states = [(1,1)]
    if col<0 or col>=num_cols or row<0 or row>=num_rows:
        return False
    if state in empty_states:
        return False
    return True

def move(s, action):
    if action=='up' and s[1]+1 <= 2:
        s_prime = (s[0], s[1]+1)
        return s_prime
        
    elif action=='left' and s[0]-1 >= 0:
        s_prime = (s[0]-1, s[1])
        return s_prime
        
    elif action=='down' and s[1]-1 >= 0:
        s_prime = (s[0], s[1]-1)
        return s_prime
        
    elif action=='right' and s[0]+1 <= 3:
        s_prime = (s[0]+1, s[1])
        return s_prime
    
    else:
        return s

def get_p_e_given_s_prime(e, new_state):
    if new_state[1] == 2:
        state_type = "non-terminal in third col"
    elif new_state==(0,3) or new_state==(1,3):
        state_type = "terminal"
    else:
        state_type = "other non-terminal"
    prob = observation_model.loc[state_type, e]
    return prob

def get_neighbors(state):
    row, col = state
    return [(row-1,col),(row+1,col),(row,col+1),(row,col-1)]

def get_prob_s_prime_given_a_s(s_prime, a, s):
  #todo: add case where s is equal to s_prime
    terminal_states = [(0,3), (1,3)]
    if s in get_neighbors(s_prime) and is_valid(s) and (s not in terminal_states):
        left_neighbor,right_neighbor,up_neighbor,down_neighbor = get_neighbors(s_prime)

        if a=="up":
            if s==left_neighbor or s==right_neighbor:
                return 0.1
            elif s==down_neighbor:
                return 0.8
            else:
                return 0

        if a=="down":
            if s==left_neighbor or s==right_neighbor:
                return 0.1
            elif s==up_neighbor:
                return 0.8
            else:
                return 0

        if a=="left":
            if s==up_neighbor or s==down_neighbor:
                return 0.1
            elif s==right_neighbor:
                return 0.8
            else:
                return 0

        if a=="right":
            if s==up_neighbor or s==down_neighbor:
                return 0.1
            elif s==left_neighbor:
                return 0.8
            else:
                return 0

    else:
        return 0

def normalize(VALID_STATES, BELIEF_STATE):
    value_total = 0
    for s in VALID_STATES:
        value_total += BELIEF_STATE[s[0]][s[1]]
    for s in VALID_STATES:
        BELIEF_STATE[s[0]][s[1]] /= value_total
    return BELIEF_STATE

def update_belief_state(s, a, e):
    s_prime = move(s, a)
    p_e_given_s_prime = get_p_e_given_s_prime(e, s_prime)
    sum_term = 0
    for neighbor in get_neighbors(s_prime):
        if is_valid(neighbor):
            sum_term += get_prob_s_prime_given_a_s(s_prime, a, neighbor) * BELIEF_STATE[neighbor[0]][neighbor[1]]
    return p_e_given_s_prime * sum_term

In [21]:
VALID_STATES = [(0,0), (0,1), (0,2), (0,3), (1,0), (1,2), (1,3), (2,0), (2,1), (2,2), (2,3)]
    
def pomdp(actions, observations, valid_states, belief_state):
    print(belief_state)
    for a,e in zip(actions, observations): 
        for s in valid_states:       
            belief_state[s[0]][s[1]] = update_belief_state(s, a, e)

        belief_state = normalize(valid_states, belief_state)
        print(belief_state)

ACTIONS = ["up", "right", "right", "right"]
OBSERVATIONS = [2,2,1,1]
pomdp(ACTIONS, OBSERVATIONS, VALID_STATES, BELIEF_STATE)

[[0.11111111 0.11111111 0.11111111 0.        ]
 [0.11111111        nan 0.11111111 0.        ]
 [0.11111111 0.11111111 0.11111111 0.11111111]]
[[0.29187958 0.03648495 0.00697268 0.        ]
 [0.33164818        nan 0.00412361 0.        ]
 [0.29187958 0.0324723  0.00263902 0.00190009]]
[[2.82110488e-01 7.58305155e-02 7.48811789e-04 0.00000000e+00]
 [3.24471225e-01            nan 9.04259987e-04 3.18836650e-04]
 [2.82110488e-01 3.13854657e-02 2.83410356e-04 1.83649910e-03]]
[[2.57123654e-01 1.06183702e-01 6.14239841e-03 0.00000000e+00]
 [2.99308158e-01            nan 4.14819098e-02 3.22885489e-05]
 [2.57123654e-01 2.86056208e-02 2.32477552e-03 1.67383838e-03]]
[[1.73721655e-01 1.00522588e-01 3.73501557e-02 0.00000000e+00]
 [2.04638836e-01            nan 2.75254631e-01 1.96337367e-04]
 [1.73721655e-01 1.93269491e-02 1.41362904e-02 1.13090323e-03]]
