# CPSC 422 - Assignment 1

## Question 3

In [1]:
import numpy as np
import copy
from sympy import Matrix, init_printing # https://stackoverflow.com/questions/13214809/pretty-print-2d-python-list
init_printing()

In [2]:
class StateSpace:
    
    numRows = 3
    numCols = 4
    invalidSpaces = {(2,2)}
    uniformProbability = 1/9

    def __init__(self, actions, observations, startingState=None):
        self.states = [[0, 0, 0, 0], [0, None, 0, 0], [0, 0, 0, 0]]
        self.setInitialBeliefState(startingState)
        self.priorStates = copy.deepcopy(self.states)
        self.actions = actions
        self.observations = observations
        assert len(self.actions) == len(self.observations), "Actions and observations must be of same length"
    
    def computeBeliefStates(self):
        for i in range(len(self.actions)):
            self.updateBeliefStates(self.actions[i], self.observations[i])
    
    def updateBeliefStates(self, action, observation):
        copyStates = copy.deepcopy(self.states)
        for c in range(1,self.numCols+1):
            for r in range(1,self.numRows+1):
                state = (c,r)
                if not self.stateIsValid(state):
                    continue
                neighbors = [(c,r+1), (c-1,r), (c,r-1), (c+1,r)]
                sumOfPriorStates = 0
                for neighbor in neighbors:
                    if self.stateIsValid(neighbor):
                        sumOfPriorStates += (self.getTransitionProbability(state, neighbor, action)*self.getPriorState(neighbor))
                self.setState(state, self.getObservationProbability(observation,state)*sumOfPriorStates)
        self.priorStates = copyStates
        
    def setInitialBeliefState(self, startingState):
        if startingState == None:
            for c in range(1,self.numCols+1):
                for r in range(1, self.numRows+1):
                    state = (c,r)
                    if self.stateIsValid(state) and not self.stateIsTerminal(state):
                        self.setState(state, self.uniformProbability)
        elif not self.stateIsValid(startingState):
            print("Could not set starting state:")
            return self.invalidStateError(startingState)
        else:
            for c in range(1,self.numCols+1):
                for r in range(1,self.numRows+1):
                    state = (c,r)
                    if self.stateIsValid(state):
                        self.setState(state, 0)
            self.setState(startingState, 1)

    def stateIsValid(self, state):
        col = state[0]
        row = state[1]
        return col >= 1 and col <= self.numCols and row >= 1 and row <= self.numRows and state not in self.invalidSpaces

    def invalidStateError(self, state):
        print(f"State ({state[0]}, {state[1]}) is invalid!")
        return None
    
    def stateIsTerminal(self, state):
        if not self.stateIsValid(state):
            print("Could not determine if state is terminal:")
            return self.invalidStateError(state)
        return (state[0] == 4 and (state[1] == 2 or state[1] == 3))

    def getState(self, state):
        if not self.stateIsValid(state):
            print("Could not get state:")
            return self.invalidStateError(state)
        col = state[0]
        row = state[1]
        return (self.states[abs(row-self.numRows)][col-1])
    
    def getPriorState(self, state):
        if not self.stateIsValid(state):
            print("Could not get state from priorStates:")
            return self.invalidStateError(state)
        col = state[0]
        row = state[1]
        return (self.priorStates[abs(row-self.numRows)][col-1])
    
    def getObservationProbability(self, observation, state):
        if not self.stateIsValid(state):
            print("Could not get observation probability:")
            return self.invalidStateError(state)
        if observation == "end" and self.stateIsTerminal(state): return 1
        elif observation == "end" or self.stateIsTerminal(state): return 0
        elif observation == 1:
            if state[0] == 3: return 0.9
            else: return 0.1
        elif observation == 2:
            if state[0] == 3: return 0.1
            else: return 0.9
    
    def getTransitionProbability(self, state, priorState, action):
        if not self.stateIsValid(state):
            print("Could not get transition probability - state error:")
            return self.invalidStateError(state)
        elif not self.stateIsValid(priorState):
            print("Could not get transition probability - priorState error:")
            return self.invalidStateError(priorState)
        elif self.stateIsTerminal(priorState): return 0
        if state == priorState:
            sameStateTransitionTable = {
                ((1,1),"up"): 0.1, ((1,1),"left"): 0.9, ((1,1),"down"): 0.9, ((1,1),"right"): 0.1,  # State (1,1)
                ((2,1),"up"): 0.8, ((2,1),"left"): 0.2, ((2,1),"down"): 0.8, ((2,1),"right"): 0.2,  # State (2,1)
                ((3,1),"up"): 0, ((3,1),"left"): 0.1, ((3,1),"down"): 0.8, ((3,1),"right"): 0.1,    # State (3,1)
                ((4,1),"up"): 0.1, ((4,1),"left"): 0.1, ((4,1),"down"): 0.9, ((4,1),"right"): 0.9,  # State (4,1)
                ((1,2),"up"): 0.2, ((1,2),"left"): 0.8, ((1,2),"down"): 0.2, ((1,2),"right"): 0.8,  # State (1,2)
                ((3,2),"up"): 0.1, ((3,2),"left"): 0.8, ((3,2),"down"): 0.1, ((3,2),"right"): 0,    # State (3,2)
                ((1,3),"up"): 0.9, ((1,3),"left"): 0.9, ((1,3),"down"): 0.1, ((1,3),"right"): 0.1,  # State (1,3)
                ((2,3),"up"): 0.8, ((2,3),"left"): 0.2, ((2,3),"down"): 0.8, ((2,3),"right"): 0.2,  # State (2,3)
                ((3,3),"up"): 0.8, ((3,3),"left"): 0.1, ((3,3),"down"): 0, ((3,3),"right"): 0.1,    # State (3,3)
            }
            return sameStateTransitionTable[state,action]
        elif state[0] == priorState[0] and state[1] == priorState[1]+1:
            # state is above priorState
            if action == "up": return 0.8
            elif action == "left": return 0.1
            elif action == "down": return 0
            elif action == "right": return 0.1
        elif state[0] == priorState[0] and state[1] == priorState[1]-1:
            # state is below priorState
            if action == "up": return 0
            elif action == "left": return 0.1
            elif action == "down": return 0.8
            elif action == "right": return 0.1
        elif state[1] == priorState[1] and state[0] == priorState[0]+1:
            # state is to the right of priorState
            if action == "up": return 0.1
            elif action == "left": return 0
            elif action == "down": return 0.1
            elif action == "right": return 0.8
        elif state[1] == priorState[1] and state[0] == priorState[0]-1:
            # state is to the left of priorState
            if action == "up": return 0.1
            elif action == "left": return 0.8
            elif action == "down": return 0.1
            elif action == "right": return 0
        else:
            return 0    
    
    def setState(self, state, val):
        if not self.stateIsValid(state):
            print("Could not set state:")
            return self.invalidStateError(state)
        col = state[0]
        row = state[1]
        self.states[abs(row-self.numRows)][col-1] = val
    
    def printStateSpace(self):
        display(Matrix(self.states))
        
#     def stateUp(self, state):
#         newState = (self.state[0], self.state[1]+1)
#         if not self.stateIsValid(newState):
#             print("Could not find state up - invalid state")
#             return none
#         return newState
    
#     def stateDown(self, state):
#         newState = (self.state[0], self.state[1]-1)
#         if not self.stateIsValid(newState):
#             print("Could not find state down - invalid state")
#             return none
#         return newState
    
#     def stateLeft(self, state):
#         newState = (self.state[0]-1, self.state[1])
#         if not self.stateIsValid(newState):
#             print("Could not find state left - invalid state")
#             return none
#         return newState
    
#     def stateRight(self, state):
#         newState = (self.state[0]+1, self.state[1])
#         if not self.stateIsValid(newState):
#             print("Could not find state right - invalid state")
#             return none
#         return newState

In [3]:
ss1 = StateSpace(["up","up","up"], [2,2,2])
ss2 = StateSpace(["up","up","up"], [1,1,1])
ss3 = StateSpace(["right","right","up"], [1,1,"end"], (2,3))
ss4 = StateSpace(["up","right","right","right"], [2,2,1,1], (1,1))
stateSpaces = [ss1, ss2, ss3, ss4]

In [4]:
for ss in stateSpaces:
    ss.computeBeliefStates()
    ss.printStateSpace()

⎡0.0594  0.009   0.000911111111111111   0.0  ⎤
⎢                                            ⎥
⎢0.0072   None   0.000177777777777778   0.0  ⎥
⎢                                            ⎥
⎣0.0018  0.0011         0.0003         0.0002⎦

⎡0.000733333333333333         0.001          0.0578   0.0  ⎤
⎢                                                          ⎥
⎢8.88888888888889e-5           None          0.0144   0.0  ⎥
⎢                                                          ⎥
⎣2.22222222222222e-5   0.000211111111111111  0.0003  0.0002⎦

⎡0.0  0.0   0.0  0.072⎤
⎢                     ⎥
⎢0.0  None  0.0   0.0 ⎥
⎢                     ⎥
⎣0.0  0.0   0.0   0.0 ⎦

⎡0.0009  0.0    0.0    0.0⎤
⎢                         ⎥
⎢ 0.0    None   0.0    0.0⎥
⎢                         ⎥
⎣0.0009  0.0   0.5184  0.0⎦

In [5]:
ss = StateSpace([], [], (2,3))

In [6]:
ss.printStateSpace()


⎡0   1    0  0⎤
⎢             ⎥
⎢0  None  0  0⎥
⎢             ⎥
⎣0   0    0  0⎦

In [7]:
for c in range(1,ss.numCols+1):
    for r in range(1,ss.numRows+1):
        state = (c,r)
        newValue = ss.getTransitionProbability(state,state,"right")
        ss.setState(state, newValue)
ss.printStateSpace()

Could not get transition probability - state error:
State (2, 2) is invalid!
Could not set state:
State (2, 2) is invalid!


⎡0.1  0.2   0.1   0 ⎤
⎢                   ⎥
⎢0.8  None   0    0 ⎥
⎢                   ⎥
⎣0.1  0.2   0.1  0.9⎦