In [1]:
import gymnasium as gym
#from lib.env.tilemap import TilemapEnv, TilemapConstraints, Ordinal
from lib.env.tilemap import TilemapConstraints, Ordinal
from lib.env.tilemapv2 import TilemapEnv
from lib.env.utils import inbounds

from lib.mdp.frozen_lake import FrozenLakeMDP
#from lib.mdp.frozen_lakev2 import FrozenLakeMDP

#from lib.mdp.tilemap import TilemapMDP
#from lib.mdp.tilemapv2 import TilemapMDP
from lib.mdp.tilemapv3 import TilemapMDP

from lib.mdp.value_iteration import ValueIteration
#from lib.mdp.value_iterationv2 import ValueIteration

#from lib.mdp.policy_iterationv2 import PolicyIteration
from lib.mdp.policy_iteration import PolicyIteration

import numpy as np
import matplotlib.pyplot as plt

import time

In [2]:
# Adjacency rules for the sand/water tileset.
# NOTE(review): each tuple appears to be (tile, sides, neighbour), where `sides`
# is a string of compass letters (N/E/S/W) on which `neighbour` may be placed
# next to `tile` — the exact semantics live in TilemapConstraints.from_rules;
# confirm there.
# NOTE(review): some pairs are listed in both directions (e.g. northeast-W-
# northwest and northwest-E-northeast) while others appear only once (e.g.
# west-N-northwest has no northwest-S-west counterpart). If from_rules does not
# mirror rules automatically, those reverse adjacencies are missing.
# NOTE(review): no rule mentions 'grass', although an image for it is built in
# the next cell — confirm whether grass is meant to be placeable.
tilemap_constraints = TilemapConstraints.from_rules([
    # Solid tiles may border themselves on every side.
    ('sand', 'NESW', 'sand'),
    ('water', 'NESW', 'water'),
    
    # Straight shoreline: water on the N side, sand on the S side.
    ('sand_water_north', 'N', 'water'),
    ('sand_water_north', 'S', 'sand'),
    ('sand_water_north', 'EW', 'sand_water_north'),
    ('sand_water_north', 'E', 'sand_water_northeast'),
    ('sand_water_north', 'W', 'sand_water_northwest'),
    ('sand_water_north', 'S', 'sand_water_south'),
    
    # Straight shoreline: water on the E side, sand on the W side.
    ('sand_water_east', 'E', 'water'),
    ('sand_water_east', 'W', 'sand'),
    ('sand_water_east', 'NS', 'sand_water_east'),
    ('sand_water_east', 'N', 'sand_water_northeast'),
    ('sand_water_east', 'S', 'sand_water_southeast'),
    ('sand_water_east', 'W', 'sand_water_west'),
    
    # Top Right corner: water on the N and E sides.
    ('sand_water_northeast', 'NE', 'water'),
    
    # Pairings with other shoreline tiles.
    ('sand_water_northeast', 'NE', 'sand_water_southwest'),
    ('sand_water_northeast', 'E', 'sand_water_west'),
    ('sand_water_northeast', 'E', 'sand_water_northwest'),
    
    ('sand_water_northeast', 'S', 'sand_water_southeast'),
    ('sand_water_northeast', 'W', 'sand_water_northwest'),
    # Presumably disabled because this corner's S/W edges are only half sand.
    #('sand_water_northeast', 'SW', 'sand'),
    
    # Straight shoreline: water on the S side, sand on the N side.
    ('sand_water_south', 'S', 'water'),
    ('sand_water_south', 'N', 'sand'),
    ('sand_water_south', 'EW', 'sand_water_south'),
    ('sand_water_south', 'E', 'sand_water_southeast'),
    ('sand_water_south', 'W', 'sand_water_southwest'),
    ('sand_water_south', 'N', 'sand_water_north'),
    
    # Bottom Right corner: water on the S and E sides.
    ('sand_water_southeast', 'SE', 'water'),
    
    # Pairings with other shoreline tiles.
    ('sand_water_southeast', 'SE', 'sand_water_northwest'),
    ('sand_water_southeast', 'E', 'sand_water_west'),
    ('sand_water_southeast', 'E', 'sand_water_southwest'),
    
    ('sand_water_southeast', 'N', 'sand_water_northeast'),
    ('sand_water_southeast', 'S', 'sand_water_northeast'),
    ('sand_water_southeast', 'W', 'sand_water_southwest'),
    # Presumably disabled because this corner's N/W edges are only half sand.
    #('sand_water_southeast', 'NW', 'sand'),
    
    # Straight shoreline: water on the W side, sand on the E side.
    ('sand_water_west', 'W', 'water'),
    ('sand_water_west', 'E', 'sand'),
    ('sand_water_west', 'NS', 'sand_water_west'),
    ('sand_water_west', 'N', 'sand_water_northwest'),
    ('sand_water_west', 'S', 'sand_water_southwest'),
    ('sand_water_west', 'E', 'sand_water_east'),
    
    # Bottom Left corner: water on the S and W sides.
    ('sand_water_southwest', 'SW', 'water'),
    
    # Pairings with other shoreline tiles.
    ('sand_water_southwest', 'SW', 'sand_water_northeast'),
    ('sand_water_southwest', 'W', 'sand_water_east'),
    ('sand_water_southwest', 'W', 'sand_water_southeast'),
    
    ('sand_water_southwest', 'N', 'sand_water_northwest'),
    ('sand_water_southwest', 'E', 'sand_water_southeast'),
    # Presumably disabled because this corner's N/E edges are only half sand.
    #('sand_water_southwest', 'NE', 'sand'),
    
    # Top Left corner: water on the N and W sides.
    ('sand_water_northwest', 'NW', 'water'),
    
    # Pairings with other shoreline tiles.
    ('sand_water_northwest', 'W', 'sand_water_east'),
    ('sand_water_northwest', 'W', 'sand_water_northeast'),
    ('sand_water_northwest', 'W', 'sand_water_southeast'),
    
    ('sand_water_northwest', 'NS', 'sand_water_southwest'),
    ('sand_water_northwest', 'E', 'sand_water_northeast'),
    # Presumably disabled because this corner's S/E edges are only half sand.
    #('sand_water_northwest', 'SE', 'sand')
])

In [3]:
# Flat RGB colours (float channels in [0, 1]) for the placeholder tile art.
grass = (0, 1.0, 0)
sand = (0.76, 0.7, 0.5)
water = (0, 0, 1.0)


def _edge_tile(base, region, overlay, *red_lines):
    """Paint one 16x16 RGB shoreline tile.

    The tile is filled with ``base``, then ``region`` (a numpy index
    expression) is painted with ``overlay``, and finally each index in
    ``red_lines`` is painted pure red to mark the sand/water boundary.
    """
    tile = np.full((16, 16, 3), fill_value=base)
    tile[region] = overlay
    for line in red_lines:
        tile[line] = (1.0, 0, 0)
    return tile


# Solid tiles first. Insertion order is kept identical throughout in case
# downstream code derives tile ids from dict order.
tilemap_images = {
    name: np.full((16, 16, 3), fill_value=colour)
    for name, colour in (('grass', grass), ('sand', sand), ('water', water))
}

tilemap_images.update({
    # Straight shorelines: a sand tile with the named half flooded.
    'sand_water_north': _edge_tile(sand, np.s_[:8, :, :], water, np.s_[8, :, :]),
    'sand_water_east': _edge_tile(sand, np.s_[:, 8:, :], water, np.s_[:, 8, :]),
    'sand_water_south': _edge_tile(sand, np.s_[8:, :, :], water, np.s_[8, :, :]),
    'sand_water_west': _edge_tile(sand, np.s_[:, :8, :], water, np.s_[:, 8, :]),
    # Corners: a water tile with a single sand quadrant.
    'sand_water_northeast': _edge_tile(water, np.s_[8:, :8, :], sand,
                                       np.s_[8, :8, :], np.s_[8:, 8, :]),
    'sand_water_southeast': _edge_tile(water, np.s_[:8, :8, :], sand,
                                       np.s_[8, :8, :], np.s_[:8, 8, :]),
    'sand_water_southwest': _edge_tile(water, np.s_[:8, 8:, :], sand,
                                       np.s_[:8, 8, :], np.s_[8, 8:, :]),
    'sand_water_northwest': _edge_tile(water, np.s_[8:, 8:, :], sand,
                                       np.s_[8:, 8, :], np.s_[8, 8:, :]),
})

# Blacken the last row and column of every tile, presumably as a visual grid
# separator when tiles are rendered side by side.
for k in tilemap_images:
    tilemap_images[k][15] = (0, 0, 0)
    tilemap_images[k][:, 15] = (0, 0, 0)


In [4]:
# Build the MDP that both solvers below operate on. FrozenLakeMDP is the active
# choice; the TilemapMDP alternative (the only consumer of tilemap_constraints /
# tilemap_images built above) is commented out, so those objects are currently
# unused by the rest of this notebook.
mdp = FrozenLakeMDP()
# st = time.time()
# %prun mdp = TilemapMDP((2, 2), tilemap_constraints, tilemap_images)
# et = time.time()
# print(f"mdp setup time: {et - st}")

In [5]:
# mdp._env.reset()
# plt.imshow(mdp._env.render())

In [6]:
# Solve the MDP with value iteration and report the wall-clock solve time.
# time.perf_counter() is monotonic and high-resolution; time.time() can jump
# with system clock adjustments and is too coarse for sub-millisecond runs.
value_iteration = ValueIteration(mdp)
st = time.perf_counter()
value_iteration.run()
et = time.perf_counter()
print(f"Value Iteration completed in {et - st} seconds")
# NOTE(review): evaluate(100) presumably runs 100 evaluation episodes and
# element [0] presumably holds per-episode returns — confirm in
# lib.mdp.value_iteration.
# NOTE(review): the recorded output shows a mean of 0.0 after ~0.5 ms, versus
# 0.74 for policy iteration on the same MDP. Both solvers would normally reach
# comparable policies, so ValueIteration may be terminating early — investigate.
print(np.mean(value_iteration.evaluate(100)[0]))

Value Iteration completed in 0.0005435943603515625 seconds
0.0


In [7]:
# Solve the same MDP with policy iteration and report the wall-clock solve time.
# time.perf_counter() is monotonic and high-resolution; time.time() can jump
# with system clock adjustments and is too coarse for short intervals.
policy_iteration = PolicyIteration(mdp)
st = time.perf_counter()
policy_iteration.run()
et = time.perf_counter()
print(f"Policy Iteration completed in {et - st} seconds")
# NOTE(review): evaluate(100) presumably runs 100 evaluation episodes and
# element [0] presumably holds per-episode returns — confirm in
# lib.mdp.policy_iteration. If evaluation is stochastic and unseeded, this
# mean will vary between runs.
print(np.mean(policy_iteration.evaluate(100)[0]))

Policy Iteration completed in 0.04411578178405762 seconds
0.74
