In [None]:
import numpy as np
import itertools

In [None]:
class TableTransition(object):
    """Callable that samples a next state from a tabular transition model.

    ``transitionTable`` is indexed as ``transitionTable[state][action]``.
    Each entry may be given in either of two forms:

    * a pair ``(nextStates, nextStateDistribution)`` of equal-length
      sequences, where ``nextStateDistribution[i]`` is the probability of
      ``nextStates[i]``; or
    * a dict ``{nextState: probability}``.
    """

    def __init__(self, transitionTable):
        self.transitionTable = transitionTable

    def __call__(self, currentState, action):
        """Sample and return the state that follows ``currentState`` under ``action``.

        Raises:
            AssertionError: if ``currentState`` is not a key of the table, or
                ``action`` is not a key of that state's entry.
        """
        assert currentState in self.transitionTable, "Current state not in set of possible states"
        assert action in self.transitionTable[currentState], "Action not valid"

        tableEntry = self.transitionTable[currentState][action]
        if isinstance(tableEntry, dict):
            # Mapping form: split into parallel sequences so both forms share the sampling code.
            nextStates = list(tableEntry.keys())
            nextStateDistribution = list(tableEntry.values())
        else:
            nextStates, nextStateDistribution = tableEntry

        # Sample an index rather than a state: states are tuples, and
        # np.random.choice only accepts 1-D inputs.
        nextStateIndex = np.random.choice(np.arange(len(nextStates)), p=nextStateDistribution)
        return nextStates[nextStateIndex]

In [24]:
def getNextState(currentStateTuple, actionTuple):
    """Return the state obtained by adding ``actionTuple`` to ``currentStateTuple`` component-wise.

    Both arguments must have the same length, otherwise the assertion fails.
    The result is always a tuple.
    """
    assert len(currentStateTuple) == len(actionTuple),"Tuples are of different length."
    return tuple(
        stateComponent + actionComponent
        for stateComponent, actionComponent in zip(currentStateTuple, actionTuple)
    )

def validCoordinate(coordinate, gridWidth, gridHeight):
    """Tell whether ``coordinate`` lies inside a ``gridWidth`` x ``gridHeight`` grid.

    ``coordinate`` is unpacked as ``(x, y)``. It is valid when neither
    component is negative, x is below ``gridWidth`` and y is below
    ``gridHeight``. Returns ``True`` or ``False``.
    """
    xCoordinate, yCoordinate = coordinate
    return not (
        xCoordinate < 0
        or xCoordinate >= gridWidth
        or yCoordinate < 0
        or yCoordinate >= gridHeight
    )

In [None]:
# Dimensions of the grid world.
gridWidth = 5
gridHeight = 10

# Every (x, y) cell of the grid.
gridCoordinates = list(itertools.product(range(gridWidth), range(gridHeight)))
# The nine unit moves (including standing still) in a deterministic order.
# The no-move action (0, 0) is deliberately first: the table-building loop
# assigns blocked moves to index 0, which is only "stay in place" when
# allActions[0] == (0, 0). Building the list via set() left that order arbitrary.
allActions = [(0, 0)] + [
    move for move in itertools.product((-1, 0, 1), repeat=2) if move != (0, 0)
]
transition = {}

In [None]:
# Build the tabular model: transition[state][action] = (nextStates, transitionPDF),
# where transitionPDF[i] is the probability of ending up in nextStates[i].

# Look up the no-move action explicitly instead of assuming it sits at index 0.
stayIndex = allActions.index((0, 0))

for xyCoordinate in gridCoordinates:
    # Candidate successors depend only on the state, so compute them once per state.
    # getNextState is used because it is defined in an earlier cell, whereas
    # elementwiseTupleAddition is only defined further down the notebook.
    nextStates = [getNextState(xyCoordinate, action) for action in allActions]

    transition[xyCoordinate] = {}
    for actionIndex, actionTuple in enumerate(allActions):
        transitionPDF = np.zeros(len(allActions))

        if validCoordinate(nextStates[actionIndex], gridWidth=gridWidth, gridHeight=gridHeight):
            # The move stays on the grid: it succeeds with certainty.
            transitionPDF[actionIndex] = 1
        else:
            # The move would leave the grid: the agent stays where it is.
            transitionPDF[stayIndex] = 1

        transition[xyCoordinate][actionTuple] = (nextStates, transitionPDF)

In [None]:
def elementwiseTupleAddition(firstTuple, secondTuple):
    """Add two equal-length sequences position by position and return the sums as a tuple.

    The assertion fails when the two arguments differ in length.
    """
    assert len(firstTuple) == len(secondTuple),"Tuples are of different length."
    componentSums = []
    for position, firstComponent in enumerate(firstTuple):
        componentSums.append(firstComponent + secondTuple[position])
    return tuple(componentSums)

def getNextState(currentStateTuple, actionTuple):
    """Return ``currentStateTuple`` shifted by ``actionTuple``, component by component.

    The two arguments must have equal length (checked with an assertion).
    The result is a tuple.
    """
    assert len(currentStateTuple) == len(actionTuple),"Tuples are of different length."
    pairedComponents = zip(currentStateTuple, actionTuple)
    return tuple(map(lambda pair: pair[0] + pair[1], pairedComponents))

In [None]:
def takeActionFromState(coordinateTuple, actions):
    """Return the coordinates reached from ``coordinateTuple`` by each action in ``actions``.

    The result is a list with one entry per action, in the same order as
    ``actions``. No bounds check is applied, so entries may lie outside the grid.
    """
    # The original body computed this list but never returned it.
    return [elementwiseTupleAddition(coordinateTuple, action) for action in actions]

In [None]:
# Map every grid coordinate to its per-action transition distributions.
# NOTE: getCoordinateTransitionDistribution is defined in a later cell of this
# notebook, so that cell has to be executed before this one.
transition = dict(
    (xyCoordinate, getCoordinateTransitionDistribution(xyCoordinate, allActions))
    for xyCoordinate in gridCoordinates
)

In [None]:
def getCoordinateTransitionDistribution(coordinateTuple, actions):
    """Return a dict mapping each action to its transition distribution from ``coordinateTuple``.

    Each value is whatever ``getStateActionTransition(coordinateTuple, action)`` returns.
    """
    distributionByAction = {}
    for action in actions:
        distributionByAction[action] = getStateActionTransition(coordinateTuple, action)
    return distributionByAction


In [25]:
def getStateActionTransition(currentState, action, noChangeInState = (0,0)):
    """Return the next-state distribution for taking ``action`` in ``currentState``.

    The result is a dict over all grid coordinates (the notebook-level
    ``gridCoordinates``) with probability 1 on a single state and 0 elsewhere:

    * the state reached by ``action`` when it lies inside the grid (checked
      against the notebook-level ``gridWidth`` and ``gridHeight``);
    * otherwise the state reached by the no-move offset ``noChangeInState``,
      i.e. the agent stays at ``currentState`` with the default ``(0, 0)``.
    """
    intendedState = elementwiseTupleAddition(currentState, action)

    if validCoordinate(intendedState, gridWidth, gridHeight):
        resultingState = intendedState
    else:
        # noChangeInState is an offset, not an absolute state. Using it
        # directly as a key would teleport the agent to cell (0, 0).
        resultingState = elementwiseTupleAddition(currentState, noChangeInState)

    # gridCoordinates is the state set defined in this notebook; the original
    # iterated over an `allStates` name that no cell defines.
    stateActionTransitionDistribution = {state: 0 for state in gridCoordinates}
    stateActionTransitionDistribution[resultingState] = 1
    return stateActionTransitionDistribution

In [None]:
# Wrap the transition table in a sampler, then draw the state that follows
# (0, 2) under action (-1, -1).
getGridTransition = TableTransition(transitionTable=transition)
getGridTransition(currentState=(0, 2), action=(-1, -1))

In [None]:
class Transition(object):
    """Samples resulting states from a transition table.

    The table is indexed as ``transition[state][action]`` and each entry is a
    pair ``(resultingStates, probabilities)`` of parallel sequences.
    """

    def __init__(self, transition):
        self.transitionTable = transition

    def transitionFunction(self, currentState, action):
        """Sample and return the state that results from taking ``action`` in ``currentState``.

        Both the state and the action are checked against the table with assertions.
        """
        table = self.transitionTable
        assert currentState in table, "Current state not in set of possible states"
        actionTable = table[currentState]
        assert action in actionTable, "Action not valid"

        candidateStates, probabilities = actionTable[action]
        # Sample a position in the candidate list, weighted by the stored probabilities.
        candidateIndices = np.arange(len(candidateStates))
        sampledIndex = np.random.choice(candidateIndices, p=probabilities)
        return candidateStates[sampledIndex]


In [None]:
def elementwiseTupleAddition(firstTuple, secondTuple):
    """Return a tuple whose i-th entry is ``firstTuple[i] + secondTuple[i]``.

    The arguments must be of equal length; an assertion enforces this.
    """
    tupleLength = len(firstTuple)
    assert tupleLength == len(secondTuple),"Tuples are of different length."
    return tuple(firstTuple[index] + secondTuple[index] for index in range(tupleLength))

def validCoordinate(coordinate, gridWidth, gridHeight):
    """Return ``True`` when ``coordinate`` is inside the grid, ``False`` otherwise.

    ``coordinate`` is unpacked as ``(x, y)``. x must satisfy neither ``x < 0``
    nor ``x >= gridWidth``, and y neither ``y < 0`` nor ``y >= gridHeight``.
    """
    xCoordinate, yCoordinate = coordinate
    # Check each axis against its own upper bound, x first.
    for value, upperBound in ((xCoordinate, gridWidth), (yCoordinate, gridHeight)):
        if value < 0 or value >= upperBound:
            return False
    return True
    

In [None]:
# Example world: a grid 5 cells wide and 10 cells high.
gridWidth = 5
gridHeight = 10
# All (x, y) cells of the grid, with x varying slowest.
gridCoordinates = list(itertools.product(range(gridWidth), range(gridHeight)))

In [None]:
# The nine unit moves, starting with the no-move action (0, 0).
allActions = [(0,0), (0,1), (1,1), (1,0), (1,-1), (0,-1), (-1,-1), (-1, 0), (-1, 1)]
# Position of the no-move action, looked up explicitly so the fallback below
# does not silently depend on the ordering of allActions.
stayIndex = allActions.index((0, 0))

# transition[state][action] = (resultingStates, transitionPDF), where
# transitionPDF[i] is the probability of ending up in resultingStates[i].
transition = {}

for xyCoordinate in gridCoordinates:
    # The candidate successors depend only on the state, so build them once
    # per state instead of once per (state, action) pair.
    resultingStates = [elementwiseTupleAddition(xyCoordinate, action) for action in allActions]

    transition[xyCoordinate] = {}
    for actionIndex, actionTuple in enumerate(allActions):
        transitionPDF = np.zeros(len(allActions))

        if validCoordinate(resultingStates[actionIndex], gridWidth=gridWidth, gridHeight=gridHeight):
            # The move stays on the grid: it succeeds with probability 1.
            transitionPDF[actionIndex] = 1
        else:
            # The move would leave the grid: the agent stays in place instead.
            transitionPDF[stayIndex] = 1

        transition[xyCoordinate][actionTuple] = (resultingStates, transitionPDF)


In [None]:
# Wrap the transition table in a Transition sampler.
exampleGame = Transition(transition=transition)

In [None]:
# Sample the state that follows (0, 9) under action (-1, -1). That move would
# give x = -1, which validCoordinate rejects as off-grid.
exampleGame.transitionFunction(currentState=(0, 9), action=(-1, -1))