In [52]:
import numpy as np
import gymnasium as gym

# Build the 4x4 FrozenLake environment and read off the sizes of its spaces.
# These module-level names are reused by the cells below.
env = gym.make('FrozenLake-v1', map_name = '4x4')
states = env.observation_space.n  # size of the observation space (one entry per grid cell)
actions = env.action_space.n  # size of the action space (0 = left, 1 = down, 2 = right, 3 = up)
n = int(np.sqrt(states))  # side length of the grid; assumes the map is square

# Functions

## `identify_states`

Consider state 9 in frozen lake. The actions that lead to this state are:
- left (0) from state 10
- down (1) from state 5
- right (2) from state 8
- up (3) from state 13

Consider state 4 in frozen lake. The actions that lead to this state are:
- left (0) from state 5
- down (1) from state 0
- up (3) from state 8

Note that inner states are accessible from 4 states and actions, while states on an edge are accessible from 3 states and actions. States on the corners are accessible from 2 states and actions. \
The `identify_states` function groups the states that lie along each edge into four lists: top row, bottom row, left column, and right column. 

In [53]:
def identify_states(n):
    """Return the states lying on each edge of an ``n`` x ``n`` grid.

    States are numbered row by row, starting at 0 in the top-left corner,
    so row ``r`` holds the states ``n * r`` through ``n * r + (n - 1)``.

    Parameters
    ----------
    n : int
        Side length of the square grid.

    Returns
    -------
    tuple of four lists
        ``(top_row, bottom_row, left_col, right_col)``; each list holds the
        ``n`` states on that edge in increasing order. Corner states appear
        in two of the lists.
    """
    first_of_last_row = n * (n - 1)

    top_row = list(range(n))
    bottom_row = [first_of_last_row + offset for offset in range(n)]
    left_col = [n * row for row in range(n)]
    right_col = [n * row + (n - 1) for row in range(n)]

    return top_row, bottom_row, left_col, right_col

In [93]:
# Compute the four edge lists for the n x n grid and report each on its own line.
top_row, bottom_row, left_col, right_col = identify_states(n)
print(
    f"States along the top row are {top_row}",
    f"States along the bottom row are {bottom_row}",
    f"States along the left column are {left_col}",
    f"States along the right column are {right_col}",
    sep="\n",
)

States along the top row are [0, 1, 2, 3]
States along the bottom row are [12, 13, 14, 15]
States along the left column are [0, 4, 8, 12]
States along the right column are [3, 7, 11, 15]


## `get_pairs`

This function will return a list of the states and actions that lead to a particular state. It uses the lists created by the `identify_states` function, the co-ordinates of corner states, and the co-ordinates of inner states to determine the states and actions. 

In [62]:
def get_pairs(state, n, left_col, right_col, top_row, bottom_row):
    """Return the ``(previous_state, action)`` pairs that lead into ``state``.

    Actions are encoded as 0 = left, 1 = down, 2 = right, 3 = up. A pair
    ``(s, a)`` means "taking action ``a`` in neighbouring state ``s`` moves
    onto ``state``". Only moves from a neighbouring cell are listed; a move
    into a wall that leaves the agent where it is is not included.

    Parameters
    ----------
    state : int
        Target state, numbered row by row from 0 in the top-left corner.
    n : int
        Side length of the square grid.
    left_col, right_col, top_row, bottom_row : list of int
        Edge states of the grid, expected to be the lists produced by
        ``identify_states(n)`` for the same ``n``.

    Returns
    -------
    list of tuple
        2 pairs for a corner, 3 for a non-corner edge state, 4 for an inner
        state, and an empty list for the single cell of a 1 x 1 grid (which
        has no neighbours).

    Raises
    ------
    ValueError
        If ``state`` is not one of the ``n * n`` states of the grid.
        ``ValueError`` is a subclass of ``Exception``, so callers that
        catch ``Exception`` are unaffected.
    """
    # Validate first so a degenerate grid can never fall into a corner branch.
    if state not in range(0, n * n):
        raise ValueError("State does not exist in the environment.")

    # A 1 x 1 grid has no neighbouring cells, hence no incoming moves.
    if n == 1:
        return []

    # Candidate incoming moves, named after the neighbour's position.
    from_right = (state + 1, 0)  # left action taken in the cell to the right
    from_above = (state - n, 1)  # down action taken in the cell above
    from_left = (state - 1, 2)   # right action taken in the cell to the left
    from_below = (state + n, 3)  # up action taken in the cell below

    # Corners are tested before the edge lists because each corner belongs
    # to two of those lists.
    if state == 0:  # top left corner
        pairs = [from_right, from_below]
    elif state == (n - 1):  # top right corner
        pairs = [from_left, from_below]
    elif state == (n * (n - 1)):  # bottom left corner
        pairs = [from_right, from_above]
    elif state == ((n ** 2) - 1):  # bottom right corner
        pairs = [from_left, from_above]
    elif state in left_col:
        pairs = [from_right, from_above, from_below]
    elif state in right_col:
        pairs = [from_above, from_left, from_below]
    elif state in top_row:
        pairs = [from_right, from_left, from_below]
    elif state in bottom_row:
        pairs = [from_right, from_above, from_left]
    else:  # inner state: reachable from all four neighbours
        pairs = [from_right, from_above, from_left, from_below]

    return pairs

In [64]:
# Reproduce the two worked examples from the prose: inner state 9 and
# left-edge state 4.
state_nine = get_pairs(9, n, left_col, right_col, top_row, bottom_row)
state_four = get_pairs(4, n, left_col, right_col, top_row, bottom_row)

print(
    f"The state action pairs that lead to state 9 are {state_nine}",
    f"The state action pairs that lead to state 4 are {state_four}",
    sep="\n",
)

The state action pairs that lead to state 9 are [(10, 0), (5, 1), (8, 2), (13, 3)]
The state action pairs that lead to state 4 are [(5, 0), (0, 1), (8, 3)]


# Main

In [66]:
def _belief_disbelief(sign, rating, u):
    """Split the mass left after uncertainty ``u`` into ``(belief, disbelief)``.

    ``sign`` is '+' (belief) or '-' (disbelief); ``rating`` is its strength:
    0 = unsure (even split), 1 = some (75% to the signed side),
    2 = high (100% to the signed side). Values are rounded to 2 decimals.

    Raises ``ValueError`` for any other sign or rating, so a bad line can
    never silently reuse the values computed for the previous line.
    """
    if sign not in ('+', '-') or rating not in (0, 1, 2):
        raise ValueError(
            f"Unrecognised opinion rating: sign={sign!r}, strength={rating!r}"
        )

    if rating == 0:
        half = round((1 - u) / 2, 2)
        return half, half

    share = 1.0 if rating == 2 else 0.75
    major = round(share * (1 - u), 2)
    minor = round(1 - (u + major), 2)
    return (major, minor) if sign == '+' else (minor, major)


n = int(np.sqrt(states))
top_row, bottom_row, left_col, right_col = identify_states(n)

# make template matrix: one (b, d, u, base rate) record per (state, action),
# initialised to full uncertainty with a uniform base rate
opinion = np.zeros((states, actions), dtype = "f, f, f, f")
for state in range(states):
    for action in range(actions):
        opinion[state, action] = (0, 0, 1, 1/actions)

# read file; the context manager closes it even if reading fails
with open('opinion.txt', 'r') as file:
    content = file.read()
lines = content.splitlines()

if not lines:
    raise ValueError("opinion.txt is empty; expected the uncertainty u on the first line.")

# u is first line
# b + d + u = 1, so based on user's uncertainty, we will split the remaining value (1 - u) between b and d based on their rating on belief/disbelief scale
u = float(lines[0])

# every following line is "state, sign, strength", e.g. "9, +, 2"
for i in range(1, len(lines)):

    # tolerate blank lines (for example a trailing newline at the end of the file)
    if not lines[i].strip():
        continue

    # strip each field so matching does not depend on exact spacing around commas
    a = [field.strip() for field in lines[i].split(',')]
    if len(a) < 3:
        raise ValueError(f"Line {i + 1} of opinion.txt is malformed: {lines[i]!r}")

    state = int(a[0])
    b, d = _belief_disbelief(a[1], int(a[2]), u)

    # the opinion about a state is attached to every (state, action) pair leading into it
    states_actions = get_pairs(state, n, left_col, right_col, top_row, bottom_row)

    for pair in states_actions:
        curr_state = pair[0]
        curr_action = pair[1]
        opinion[curr_state, curr_action]= (b, d, u, 1/actions)

# np.save returns None, so its result must not be assigned back to `opinion`
np.save('opinion', opinion)