In [1]:
import numpy as np
import random
import matplotlib.pyplot as plt
plt.style.use(["science", "notebook", "grid"])

In [2]:
%config Completer.use_jedi = False

In [3]:
import modified_frozen_lake.frozen_lake as frozen_lake

In [4]:
# Build the 4x4 map from the local `modified_frozen_lake` package.
# NOTE(review): slip_rate=0.1 is presumably the probability that a move slips to
# an unintended direction — confirm against modified_frozen_lake/frozen_lake.py.
env = frozen_lake.FrozenLakeEnv(map_name="4x4", slip_rate=0.1)

In [5]:
env.action_space.n, env.observation_space.n

(4, 16)

In [6]:
initial_state = env.reset()

In [7]:
initial_state, env.render()


[41mS[0mFFF
FHFH
FFFH
HFFG


(0, None)

In [8]:
# Play one episode with uniformly random actions, drawing the board after every
# move. The loop body always runs at least once, so `a`, `next_state`, `reward`,
# `done` and `extra` hold the final transition when the cell finishes.
initial_state = env.reset()
env.render()
while True:
    a = env.action_space.sample()
    next_state, reward, done, extra = env.step(a)
    env.render()
    if done:
        break


[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Down)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Up)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Up)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Up)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Down)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Up)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Up)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Up)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Right)
S[41mF[0mFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Up)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Down)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Down)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Left)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Down)
SFFF
FHFH
FFFH
[41mH[0mFFG
  (Down)
SFFF
FHFH
FFFH
[41mH[0mFFG


In [154]:
a, next_state, reward, done, extra

(3, 12, -100, True, {'prob': 1.0})

In [155]:
env.render()

  (Down)
SFFF
FHFH
FFFH
[41mH[0mFFG


In [11]:
# Cache the sizes of the discrete state and action spaces; the policy classes
# and example cells further down take them as constructor arguments.
nS = env.observation_space.n
nA = env.action_space.n
nS, nA

(16, 4)

In [59]:
# Deliberately mis-shaped (20 x 5) matrix: it is passed to Policy(7, 8, o)
# further down to show the shape check rejecting it.
o = np.ones((20,5))
# o

It's probably not wise to use dataclasses for `Policy`, because I would have to implement `__eq__`, `__lt__`, etc. manually anyway. The only gain seems to be the auto-generated `__repr__` method, but that is not hard to write either.
Also, I need to write subclasses for $\epsilon$-soft policies and deterministic policies, and I'm not an expert in subclassing dataclasses and their `__repr__` methods!

Also, policies have a partial ordering among them, but checking it requires the state values, i.e.
$\pi_1 \geq \pi_2$ iff $V_{\pi_1}(s) \geq V_{\pi_2}(s)$ $\forall s \in S$.

In [185]:
class Policy(object):
    """
    Tabular stochastic policy pi(a|s) over `nS` states and `nA` actions.

    `pi` is an (nS x nA) matrix whose row `s` is the probability distribution
    over actions in state `s`. When no matrix is given, the equiprobable
    policy (every action equally likely in every state) is used.

    Instances define `__eq__` without `__hash__`, so they are unhashable.
    """

    def __init__(self, nS: int, nA: int, pi: np.ndarray=None):
        """
        Parameters
        ----------
        nS : number of states.
        nA : number of actions.
        pi : optional (nS x nA) array-like of action probabilities. A float
             ndarray is stored by reference (not copied); anything else is
             converted to a new float array.

        Raises
        ------
        AssertionError if `pi` has the wrong shape, contains negative entries,
        or has a row that does not sum to 1.
        """
        self.nS, self.nA = nS, nA
        if pi is None:
            # generates an equiprobable policy for all actions.
            pi = np.ones((self.nS, self.nA))
            pi /= pi.sum(axis=1, keepdims=True)
        else:
            pi = np.asarray(pi, dtype=float)
            assert pi.shape == (self.nS, self.nA), f"policy matrix shape should be {self.nS} x {self.nA}"
        self.pi = pi
        # Every row must be a probability distribution on its own. Checking only
        # the grand total would accept e.g. rows summing to 2, 0 and 1.
        assert np.all(self.pi >= 0), "policy probabilities must be non-negative"
        assert np.allclose(self.pi.sum(axis=1), 1.0), (
            f"every row of pi must sum to 1; {self.nS=} self.pi.sum()={self.pi.sum()}"
        )

    def __repr__(self):
        return f"Policy(nS={self.nS}, nA={self.nA}, pi={self.pi.__repr__()})"

    def __eq__(self, other):
        """
        Two policies are equal when they are of exactly the same class, have the
        same dimensions and their probability matrices agree up to `np.allclose`
        tolerance.
        """
        if self.__class__ is not other.__class__:
            # Returning NotImplemented (rather than raising) lets Python try the
            # reflected comparison and fall back to identity, so `policy == 3`
            # evaluates to False instead of blowing up.
            return NotImplemented
        return self.nS == other.nS and self.nA == other.nA and np.allclose(self.pi, other.pi)

    def update(self, q: np.ndarray, state: int):
        """
        Improve the policy in `state` with respect to the (nS x nA) action-value
        function `q`. Must be overridden by subclasses.

        det. policy update:
        -------------------
        pi(state) = argmax_a q(state, a)

        """
        raise NotImplementedError("need to implement differently for eps-soft & det. policies!")

    def sample_action(self, state: int) -> int:
        """
        returns the sampled action from the distribution policy pi(a|s)

        Raises AssertionError when `state` is outside {0, 1, ..., nS-1}.
        """
        # The lower bound matters: a negative index would silently wrap around
        # and sample from a different state's row.
        assert 0 <= state < self.nS, f"state must be in {{0, 1, ..., {self.nS-1}}}"
        return int(np.random.choice(self.nA, p=self.pi[state]))

In [186]:
class EpsilonSoftPolicy(Policy):
    """
    Stochastic policy that is improved with the epsilon-soft (epsilon-greedy)
    rule: the greedy action gets probability 1 - eps + eps/nA and every other
    action gets eps/nA.
    """

    def __init__(self, epsilon: float, nS: int, nA: int, pi: np.ndarray=None):
        """
        Parameters
        ----------
        epsilon : exploration rate, must lie in [0, 1].
        nS, nA  : number of states / actions.
        pi      : optional (nS x nA) initial policy matrix, forwarded to `Policy`.
        """
        # Outside [0, 1] the update rule would write negative or >1 probabilities.
        assert 0.0 <= epsilon <= 1.0, f"epsilon must be in [0, 1], got {epsilon}"
        self.epsilon = epsilon
        super().__init__(nS, nA, pi)

    def __repr__(self):
        return f"EpsilonSoftPolicy(epsilon={self.epsilon}, nS={self.nS}, nA={self.nA}, pi={self.pi.__repr__()})"

    def update(self, q: np.ndarray, state: int):
        """
        update policy wrt q-function for given state.
        `q` is the (nS x nA) action value function.
        eps-soft update algo:
        ---------------------
                 a* <- argmax_a q(state, a)
        pi(a|state) <- (1-eps+eps/nA) if a==a* else eps/nA

        Ties in q(state, .) are broken in favour of the lowest action index
        (np.argmax behaviour). Row `state` of `self.pi` is overwritten in place;
        no other row is touched. Returns None.
        """
        assert 0 <= state < self.nS, f"state must be in {{0, 1, ..., {self.nS-1}}}"
        q = np.asarray(q)
        assert q.shape == (self.nS, self.nA), f"q matrix shape should be {self.nS} x {self.nA}"
        greedy_action = int(np.argmax(q[state]))
        if not np.issubdtype(self.pi.dtype, np.floating):
            # an integer-typed matrix would truncate the fractional probabilities
            self.pi = self.pi.astype(float)
        self.pi[state, :] = self.epsilon / self.nA
        self.pi[state, greedy_action] += 1.0 - self.epsilon

In [187]:
esp = EpsilonSoftPolicy(0.01, nS, nA)
print(esp, esp.sample_action(0))

EpsilonSoftPolicy(epsilon=0.01, nS=16, nA=4, pi=array([[0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25]])) 2


In [153]:
p = Policy(nS, nA)
p1 = Policy(nS-10, nA-2)
p, p1, p==p1

(Policy(nS=16, nA=4, pi=array([[0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25],
        [0.25, 0.25, 0.25, 0.25]])),
 Policy(nS=6, nA=2, pi=array([[0.5, 0.5],
        [0.5, 0.5],
        [0.5, 0.5],
        [0.5, 0.5],
        [0.5, 0.5],
        [0.5, 0.5]])),
 False)

In [156]:
# A (20 x 5) matrix cannot describe a 7-state, 8-action policy, so the
# constructor must reject it. The error is caught and displayed so that
# Restart & Run All still reaches the cells below.
try:
    Policy(7, 8, o)
except AssertionError as err:
    print(f"AssertionError: {err}")

AssertionError: policy matrix shape should be 7 x 8

In [158]:
p.sample_action(15)

2

In [160]:
# The base class leaves `update` abstract. It has to be called with the
# (q, state) arguments its signature requires; otherwise Python raises a
# TypeError before the NotImplementedError being demonstrated is reached.
try:
    p.update(np.zeros((nS, nA)), 0)
except NotImplementedError as err:
    print(f"NotImplementedError: {err}")

NotImplementedError: need to implement differently for eps-soft & det. policies!