In [1]:
# %load_ext autoreload
# %autoreload 2
from typing import Tuple, List, Dict, Any
from modules.loggers import TabQLogger
from modules.algorithm import AbstractAlgorithm

import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

import pickle
import os
from collections import defaultdict

# CliffWalking instance created without a render_mode argument.
no_render_env = gym.make("CliffWalking-v0")

# Same environment id, created with render_mode='human'.
render_env = gym.make("CliffWalking-v0", render_mode='human')

In [2]:
class TabularQLearn(AbstractAlgorithm):
    """Tabular Q-learning agent with an epsilon-greedy behaviour policy.

    Q-values live in a ``defaultdict`` keyed by the string form of an
    observation; each value is a 1-D NumPy array with one entry per action,
    lazily created and filled with ``q_init`` the first time a state is seen.

    NOTE(review): ``self._env``, ``self._n_actions`` and ``self.is_discrete``
    are not set in this class; they are presumably provided by
    ``AbstractAlgorithm.__init__`` -- confirm against that base class.
    """

    def __init__(self,
                 env: gym.Env,
                 *args,
                 alpha: float = 0.1,
                 gamma: float = 0.99,
                 epsilon0: float = 0.9,
                 epsilon_decay: float = 0.999,
                 q_init: float = 0.6,
                 **kwargs):
        """Create the agent.

        Args:
            env: Environment forwarded to the base-class constructor.
            *args: Extra positional arguments forwarded to the base class.
            alpha: Learning rate used in the Q-value update.
            gamma: Discount factor applied to the next state's best Q-value.
            epsilon0: Initial exploration probability.
            epsilon_decay: Factor epsilon is multiplied by after each episode
                (while epsilon is above 0.001).
            q_init: Initial Q-value for every action of a newly seen state.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(env, *args, **kwargs)
        self.logger = TabQLogger()
        self._alpha = alpha
        self._gamma = gamma
        self._epsilon = epsilon0
        self._epsilon_decay = epsilon_decay
        self._q_init = q_init

        self._q_table = self._new_q_table()

        self.hyperparams = {
            'alpha': alpha,
            'gamma': gamma,
            'q_init': q_init,
            'epsilon0': epsilon0,
            'epsilon_decay': epsilon_decay
        }

    def _new_q_table(self):
        """Return an empty Q-table whose missing states default to ``q_init``."""
        return defaultdict(lambda: np.full(self._n_actions, self._q_init))

    def _encode_observation(self, observation):
        """Turn an observation into a hashable Q-table key.

        Raises:
            NotImplementedError: If ``self.is_discrete`` is falsy.
        """
        if self.is_discrete:
            return str(observation)

        # TODO: discretise continuous observations. Earlier sketch:
        # np.minimum(self.n // 2, np.maximum(-self.n // 2,
        #     (observation / self.state_step_size).astype('int64'))).__str__()
        raise NotImplementedError

    def get_action(self, observation):
        """Follow e-greedy policy to get action for a given state."""
        # Looking the state up first also registers it in the table.
        qvals = self._q_table[self._encode_observation(observation)]

        # e-greedy policy: explore with probability epsilon, else exploit.
        if np.random.rand() < self._epsilon:
            return np.random.choice(self._n_actions)

        return np.argmax(qvals)

    def _update(self, observation, action, reward, next_observation, terminated=False):
        """Apply one Q-learning update for a single transition.

        Args:
            observation: Observation the action was taken in.
            action: Index of the action taken.
            reward: Reward received for the transition.
            next_observation: Observation returned by the environment.
            terminated: True when the transition ended the episode in a
                terminal state. In that case nothing is bootstrapped from
                ``next_observation``, so a not-yet-zeroed (``q_init``) terminal
                row cannot leak into the target. Defaults to False, which
                reproduces the plain bootstrapped update.
        """
        state = self._encode_observation(observation)
        next_state = self._encode_observation(next_observation)

        # A terminal state has no future return; otherwise use max_a' Q(s', a').
        future_value = 0.0 if terminated else np.max(self._q_table[next_state])

        # Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * future_value)
        self._q_table[state][action] = (1 - self._alpha) * self._q_table[state][action] + self._alpha * (
            reward + self._gamma * future_value)

    def train(self, n_episodes):
        """Run ``n_episodes`` of Q-learning on ``self._env``.

        After each episode the total reward is logged and epsilon is decayed
        (while above 0.001). Every 1000 episodes (except episode 0) the Q-table
        is handed to the logger for checkpointing and a progress line is
        printed.
        """
        CHECKPOINT_INTERVAL = 1000
        self.logger.write_hyperparameters(self.hyperparams)

        for episode in range(n_episodes):
            observation, _ = self._env.reset()
            total_reward = 0
            done = False

            while not done:
                action = self.get_action(observation)
                next_observation, reward, terminated, truncated, _ = self._env.step(action)
                # Pass `terminated` so the target for a terminal transition is
                # just the reward (truncation still bootstraps).
                self._update(observation, action, reward, next_observation, terminated)
                observation = next_observation
                total_reward += reward

                done = terminated or truncated
                if terminated:
                    self._set_terminal_observation(observation)

            self.logger.log('total_reward', total_reward)

            if episode % CHECKPOINT_INTERVAL == 0 and episode > 0:
                self.logger.save_checkpoint(self._q_table, 'qvalues.pkl')
                print(f'Episode {episode}/{n_episodes} | epsilon: {self._epsilon} | mean total reward: {np.mean(self.logger.metrics["total_reward"][-50:])}')

            if self._epsilon > 0.001:
                self._epsilon *= self._epsilon_decay

    def load(self, path):
        """Load Q-values from a pickle file at ``path``.

        The unpickled mapping is copied into a fresh default-initialised
        table, so states missing from the file still get ``q_init`` values
        instead of raising ``KeyError`` when the file holds a plain ``dict``.

        WARNING: ``pickle.load`` can execute arbitrary code; only load files
        you created yourself.
        """
        with open(path, 'rb') as f:
            stored = pickle.load(f)

        q_table = self._new_q_table()
        q_table.update(stored)
        self._q_table = q_table

    def _set_terminal_observation(self, terminal_observation):
        """Set every Q-value of a terminal observation to zero."""
        self._q_table[self._encode_observation(terminal_observation)] = np.zeros(self._n_actions)

    def render(self, env: gym.Env, n_episodes: int = 1, state_dict_path: str = None):
        """Play ``n_episodes`` in ``env`` using ``get_action`` and close it.

        Args:
            env: Environment whose ``render_mode`` must be ``'human'``.
            n_episodes: Number of episodes to play.
            state_dict_path: Optional path to a pickled Q-table. It is loaded
                only if given and the file exists; otherwise the Q-values
                currently in memory are used.

        Note:
            Actions come from ``get_action``, so the current epsilon still
            applies while rendering.
        """
        assert env.render_mode == 'human', 'Render mode must be set to human'

        # os.path.exists(None) raises TypeError, so guard the default.
        if state_dict_path is not None and os.path.exists(state_dict_path):
            self.load(state_dict_path)

        try:
            for _ in range(n_episodes):
                observation, _info = env.reset()
                done = False
                while not done:
                    action = self.get_action(observation)
                    observation, _reward, terminated, truncated, _info = env.step(action)
                    done = terminated or truncated
        finally:
            # Release the environment even if an episode fails.
            env.close()

In [3]:
# The agent's epsilon-greedy policy draws from NumPy's global RNG, so seed it
# to make the agent's exploration repeatable on Restart & Run All.
# (The environment's own RNG is not seeded here.)
SEED = 42
N_EPISODES = 10000
np.random.seed(SEED)

env = gym.make("CliffWalking-v0")
tabq = TabularQLearn(env)
tabq.train(N_EPISODES)

Episode 1000/10000 | epsilon: 0.33092588229386716 | mean total reward: -148.48
Episode 2000/10000 | epsilon: 0.12167993285774943 | mean total reward: -56.82
Episode 3000/10000 | epsilon: 0.04474115459823254 | mean total reward: -28.32
Episode 4000/10000 | epsilon: 0.0164511178447405 | mean total reward: -19.52
Episode 5000/10000 | epsilon: 0.006049000763879044 | mean total reward: -15.16
Episode 6000/10000 | epsilon: 0.0022241899053143913 | mean total reward: -13.04
Episode 7000/10000 | epsilon: 0.0009999929953143968 | mean total reward: -13.04
Episode 8000/10000 | epsilon: 0.0009999929953143968 | mean total reward: -13.0
- Current mean reward: -13.02, but best mean reward is: -13.0
- Mean reward did not improve. Checkpoint not saved.
Episode 9000/10000 | epsilon: 0.0009999929953143968 | mean total reward: -13.02


In [4]:
# Replay 10 episodes in a human-rendered CliffWalking environment, passing the
# checkpoint path 'results/qvalues.pkl' for the agent to load.
tabq.render(
    gym.make("CliffWalking-v0", render_mode='human'),
    n_episodes=10,
    state_dict_path='results/qvalues.pkl',
)