In [1]:
!pip install pymdptoolbox
import numpy as np
import itertools as it
import mdptoolbox
from scipy.sparse import csr_matrix
import datetime



In [2]:
warehouseSize = 4
cellStates = ["w","r","b","e"] #white, red, blue and empty
orders = ["sw", "sr", "sb", "rw", "rr", "rb"]
states = ['0','w','b','r']
actions = ['sto_r', 'sto_b', 'sto_w', 'res_r', 'res_b', 'res_w']
all_states = []
for i in states:
  for j in states:
      for k in states:
          for l in states:
              for m in actions:
                  st_ac = (i,j,k,l,m)
                  all_states.append(st_ac)
statesDict = {x:i for i,x in enumerate(states)}

negRewards = {
    5:-4,
    4:-3,
    3:-3,
    2:-2,
    1:-2,
    0:-1,
    "error": -100
}

posRewards = {
    5:2,
    4:4,
    3:4,
    2:6,
    1:6,
    0:8,
    "error": -100
}

rewards = posRewards

In [3]:
def getOptions(state, action):
    #take state and action (index of cell targeted)
    #return list of indices for states-list, that represent all the viable follow-up states
    viable = list()
    valid = bool()
    
    order = str()
    if state[-1][0] == "s":
        order = "store"
        valid = state[action] == "e"
    else:
        order = "restore"
        valid = state[action] == state[-1][-1]
    
    if valid:
        if order == "store":
          viable = [i for i,x in enumerate(states) if (x[action] == state[-1][-1]) and ([y for j,y in enumerate(x) if j!=action][:-1] == [y for j,y in enumerate(state) if j!=action][:-1])]
        else:
          viable = [i for i,x in enumerate(states) if (x[action] == "e") and ([y for j,y in enumerate(x) if j!=action][:-1] == [y for j,y in enumerate(state) if j!=action][:-1])]
      
    else:
        viable = [i for i,x in enumerate(states) if (state[:-1] == x[:-1])]
    
    return viable, valid

In [4]:
def convertStates(s):
    #convert list of indices to corresponding states
    conS = list()
    for x in s:
        conS.append(states[x])
        
    return conS

In [5]:
def testWarehouse(orderList, policy):
  #test the performance of a policy on a list of orders
  state = ["e","e","e","e",""]
  
  cost = 0
  
  for order in orderList:
    state[-1] = order
    cell = policy[statesDict[tuple(state)]]
    
    if state[-1][0] == "s":
      state[cell] = state[-1][-1]
    else:
      state[cell] = "e"
      
    if cell == 0:
      cost += 1
    elif cell in [1,2]:
      cost += 2
    else:
      cost += 3
  return cost

In [6]:
def visState(state):
  #visualize the structure of a state
  if len(state) == 5:
    print(state[-1])
    print([state[2],state[3]])
    print([state[0],state[1]])
    print("\n")
  else:
    pass

In [9]:
def createGreedyPolicy():
  pol = list()
  
  for state in states:
    order = state[-1][0]
    
    if order == "s":
      #state tries to store object
      valid = False
      for i in range(4):
        if state[i] == "e":
          #state i is empty
          valid = True
          pol.append(i)
          break
      if not valid:
        #there is no empty cell to store an object
        pol.append(3)
        
    else:
      #state tries to restore object
      valid = False
      for i in range(4):
        if state[i] == state[-1][-1]:
          #cell i contains matching object
          valid = True
          pol.append(i)
          break
      if not valid:
        #there is no matching object to restore
        pol.append(3)
  return pol

In [15]:
print(datetime.datetime.now())

trans_prob_all = []

#probability for every next state, assuming an equal probability distribution
prob = 1/6 
#explicit calculation of the last probability to avoid floating-point errors
problast = 1 - (1/6)*5

#create on mat
for i in range(warehouseSize):
    trans_prob = np.zeros((1536, 1536))    
    for x in range(len(trans_prob)):
        opts, valid = getOptions(states[x], i)

        if len(opts)!= 6:
          print("ERROR: Wrong amount of options: ", str(len(opts)))
          
        for j,k in enumerate(opts):
            if j == (len(opts)-1):
              trans_prob[x][k] = problast
            else:
              trans_prob[x][k] = prob


    trans_prob_all.append(csr_matrix(trans_prob))
    print("Done: ", str(i))


print("Trans Matrices created!")



print(datetime.datetime.now())

2021-07-14 12:52:42.532757
ERROR: Wrong amount of options:  0
ERROR: Wrong amount of options:  0
ERROR: Wrong amount of options:  0
ERROR: Wrong amount of options:  0


IndexError: list index out of range

In [12]:
print(datetime.datetime.now())
reward_all = []
for state in states:
    reward = np.zeros(4, dtype=np.float16)
    
    order = str()
    if state[-1][0] == "s":
        order = "store"
    else:
        order = "restore"
    
    if order == "store":
        for i,x in enumerate(reward):
            if state[i]=="e":
                reward[i] = rewards[i]
            else:
                reward[i] = rewards["error"]
    else:
        for i,x in enumerate(reward):
            if state[i] == state[-1][-1]:
                reward[i] = rewards[i]
            else:
                reward[i] = rewards["error"]
    reward_all.append(reward)
reward_all = np.array(reward_all)
print("Done Rewards!")
print(datetime.datetime.now())

2021-07-14 12:36:17.352607


IndexError: string index out of range

In [None]:
mdpresultPolicy = mdptoolbox.mdp.PolicyIteration(trans_prob_all, reward_all,0.999, max_iter=10000)
mdpresultValue = mdptoolbox.mdp.ValueIteration(trans_prob_all, reward_all,0.999, max_iter=10000)

In [None]:
mdpresultPolicy.run()
mdpresultValue.run()

print("MDP trained")

In [14]:
'''The Values'''

print('PolicyIteration:')
print(mdpresultPolicy.policy)
print(mdpresultPolicy.V)
print(mdpresultPolicy.iter)


print('ValueIteration:')
print(mdpresultValue.policy)
print(mdpresultValue.V)
print(mdpresultValue.iter)

PolicyIteration:


NameError: name 'mdpresultPolicy' is not defined

In [None]:
print("mdpValue Policy: ",testWarehouse(orderList, mdpresultValue.policy))
print("mdpPolicy Policy: ",testWarehouse(orderList, mdpresultPolicy.policy))
print("Greedy Policy: ",testWarehouse(orderList, createGreedyPolicy()))