# Deep Sea Treasure Envs

In [1]:
from collections import Counter
from typing import Tuple
from time import time
import base64
import os 

from IPython.display import clear_output
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

import mo_gymnasium as mo_gym
import gymnasium as gym

import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch as T

import warnings
warnings.filterwarnings('ignore')

### Agent Brain

In [4]:
class ReplayBuffer:
    """
        Fixed-capacity circular store of (state, action, reward, next_state, done) transitions
        used for experience replay.
    """

    def __init__(self, max_size: int, input_shape: list) -> None:
        """
            Allocates zero-filled NumPy arrays for every transition field.

            Parameters:
                - max_size (int): Capacity of the buffer; older transitions are overwritten once it is full.
                - input_shape (list): Shape of a single observation.

            Returns:
                - None
        """

        self.mem_size = max_size
        self.mem_cntr = 0  # total number of transitions ever stored (not capped at mem_size)

        obs_block_shape = (self.mem_size, *input_shape)
        self.state_memory = np.zeros(obs_block_shape, dtype=np.float32)
        self.new_state_memory = np.zeros(obs_block_shape, dtype=np.float32)

        self.action_memory = np.zeros(self.mem_size, dtype=np.int64)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)

        # Flags transitions whose next state is terminal, so no future rewards are bootstrapped from them
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    def store_transition(self, state, action, reward: float, state_, done: bool) -> None:
        """
            Writes one transition into the next slot, wrapping around when the buffer is full.

            Parameters:
                - state (np.array): Observation before the action.
                - action (int): Action taken.
                - reward (float): Reward received.
                - state_ (np.array): Observation after the action.
                - done (bool): Whether the episode ended with this transition.

            Returns:
                - None
        """

        slot = self.mem_cntr % self.mem_size  # oldest entry is overwritten once full

        for store, value in (
            (self.state_memory, state),
            (self.new_state_memory, state_),
            (self.action_memory, action),
            (self.reward_memory, reward),
            (self.terminal_memory, done),
        ):
            store[slot] = value

        self.mem_cntr += 1

    def sample_buffer(self, batch_size: int) -> tuple:
        """
            Draws `batch_size` distinct stored transitions uniformly at random.

            Args:
                batch_size (int): Number of transitions to draw; must not exceed the number stored.

            Returns:
                tuple: (states, actions, rewards, next_states, terminal_flags) as NumPy arrays.
        """

        # Only the filled part of the buffer is eligible
        filled = self.mem_size if self.mem_cntr > self.mem_size else self.mem_cntr

        picks = np.random.choice(filled, batch_size, replace=False)

        return tuple(
            store[picks]
            for store in (
                self.state_memory,
                self.action_memory,
                self.reward_memory,
                self.new_state_memory,
                self.terminal_memory,
            )
        )

In [5]:
class DuelingDeepNetwork(nn.Module):
    """
        A dueling deep Q-network: one shared hidden layer feeding a state-value head V(s)
        and an advantage head A(s, a).
    """

    def __init__(self, learning_rate: float, n_actions: int, input_dims: list, name: str, chkpt_dir: str) -> None:
        """
            Initializes the DuelingDeepNetwork class.

            Parameters:
                - learning_rate (float): The learning rate for the Adam optimizer.
                - n_actions (int): The number of actions in the environment.
                - input_dims (list): The dimensions of the input state (unpacked into nn.Linear, so
                  presumably a single-element list holding the flat state size).
                - name (str): The name of the network; also the checkpoint file name.
                - chkpt_dir (str): The directory to save the network's checkpoints.

            Returns:
                - None
        """

        super(DuelingDeepNetwork, self).__init__()

        self.name = name

        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, self.name)

        self.fc1 = nn.Linear(*input_dims, 512)

        # Dueling heads: scalar V(s) and one advantage per action
        self.value = nn.Linear(512, 1)
        self.advantage = nn.Linear(512, n_actions)

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        self.loss = nn.MSELoss()

        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state: T.Tensor) -> Tuple[T.Tensor, T.Tensor]:
        """
            Performs a forward pass on the network.

            Parameters:
                - state (T.Tensor): The input state batch.

            Returns:
                tuple[T.Tensor, T.Tensor]: The value and advantage outputs of the network.
        """

        x = F.relu(self.fc1(state))

        return self.value(x), self.advantage(x)

    def _best_file(self, final_state: tuple) -> str:
        """Path of the per-final-state checkpoint written by save_best."""
        return os.path.join(self.chkpt_dir, f'{self.name}_best_{final_state}')

    def _ensure_chkpt_dir(self) -> None:
        """Creates the checkpoint directory if missing (T.save does not create directories)."""
        if self.chkpt_dir:
            os.makedirs(self.chkpt_dir, exist_ok=True)

    def save_checkpoint(self) -> None:
        """
            Saves the network's weights to `chkpt_file`.

            Returns:
                - None
        """

        print('\tSaving checkpoint...')
        self._ensure_chkpt_dir()
        T.save(self.state_dict(), self.chkpt_file)

    def save_best(self, final_state: tuple) -> None:
        """
            Saves the network's weights under a file name tagged with the given final state.

            Parameters:
                - final_state (tuple): The final state of the environment.

            Returns:
                - None
        """

        print(f'\tSaving {self.name} with best score...')
        self._ensure_chkpt_dir()
        T.save(self.state_dict(), self._best_file(final_state))

    def load_checkpoint(self, final_state: tuple = None) -> None:
        """
            Loads the network's weights.

            Parameters:
                - final_state (tuple, optional): When given, loads the file written by
                  `save_best(final_state)`. Otherwise loads the file written by `save_checkpoint`
                  (falling back to the legacy `<chkpt_file>_best` path if only that exists).

            Returns:
                - None
        """

        print('Loading checkpoint...')

        if final_state is not None:
            path = self._best_file(final_state)
        else:
            path = self.chkpt_file
            legacy_path = f'{self.chkpt_file}_best'
            if not os.path.exists(path) and os.path.exists(legacy_path):
                path = legacy_path

        # map_location allows restoring a GPU-saved checkpoint on a CPU-only machine (and vice versa)
        self.load_state_dict(T.load(path, map_location=self.device))

In [10]:
class Agent:
    """
        Dueling Double-DQN agent that interacts with the environment and learns to make decisions.

        It owns a replay buffer and two DuelingDeepNetwork instances: `q_eval` (trained online)
        and `q_next` (target network, periodically overwritten with `q_eval`'s weights).
    """

    def __init__(
        self, gamma: float, epsilon: float, learning_rate: float, n_actions: int,
        input_dims: list, mem_size: int, batch_size: int,
        eps_min: float = 0.01, eps_decay: float = 5e-7,
        replace: int = 1000,
        chkpt_dir: str = 'dddqn_bk'
    ) -> None:
        """
            Initializes the Agent object.

            Args:
                - gamma (float): Discount factor for future rewards.
                - epsilon (float): Exploration rate, the probability of taking a random action.
                - learning_rate (float): Learning rate for the neural network optimizer.
                - n_actions (int): Number of possible actions in the environment.
                - input_dims (list): Dimensions of the input state.
                - mem_size (int): Size of the replay memory buffer.
                - batch_size (int): Number of samples to train on in each learning iteration.
                - eps_min (float, optional): Minimum value for epsilon. Defaults to 0.01.
                - eps_decay (float, optional): Amount subtracted from epsilon after each learning step. Defaults to 5e-7.
                - replace (int, optional): Number of learning steps between target-network updates. Defaults to 1000.
                - chkpt_dir (str, optional): Directory to save checkpoints. Defaults to 'dddqn_bk'.
        """

        self.epsilon = epsilon
        self.lr = learning_rate
        self.gamma = gamma

        self.input_dims = input_dims
        self.n_actions = n_actions

        self.batch_size = batch_size
        self.mem_size = mem_size

        self.eps_decay = eps_decay
        self.eps_min = eps_min

        self.replace_target_cnt = replace
        self.learn_step_cnt = 0

        self.chkpt_dir = chkpt_dir

        self.action_space = list(range(self.n_actions))
        self.memory = ReplayBuffer(self.mem_size, self.input_dims)

        self.q_eval = DuelingDeepNetwork(
            self.lr, self.n_actions, self.input_dims,
            'dst_dddqn_q_eval',
            self.chkpt_dir
        )

        self.q_next = DuelingDeepNetwork(
            self.lr, self.n_actions, self.input_dims,
            'dst_dddqn_q_next',
            self.chkpt_dir
        )

    def choose_action(self, observation: tuple) -> Tuple[int, str]:
        """
            Epsilon-greedy action selection.

            Parameters:
                observation (tuple): The current observation.

            Returns:
                tuple[int, str]: The chosen action and its origin ('NN' for greedy, 'Rand' for random).
        """

        if np.random.random() > self.epsilon:
            # Greedy action. argmax over A(s, .) equals argmax over Q(s, .) because V(s) is shared by all actions.
            with T.no_grad():
                state = T.tensor(np.array([observation]), dtype=T.float).to(self.q_eval.device)
                _, advantage = self.q_eval.forward(state)
                action = T.argmax(advantage).item()
            action_type = 'NN'

        else:
            action = np.random.choice(self.action_space)
            action_type = 'Rand'

        return action, action_type

    def store_transition(self, state, action, reward: float, state_, done: bool) -> None:
        """
            Stores a transition in the replay memory buffer.

            Parameters:
                - state (np.array): The current state of the environment.
                - action (int): The action taken in the current state.
                - reward (float): The reward received for taking the action.
                - state_ (np.array): The next state of the environment.
                - done (bool): Whether the episode is done after taking the action.

            Returns:
                - None
        """

        self.memory.store_transition(state, action, reward, state_, done)

    def replace_target_network(self) -> None:
        """
            Copies the evaluation network's weights into the target network every
            `replace_target_cnt` learning steps (and on the very first step).

            Returns:
                None
        """

        if self.learn_step_cnt % self.replace_target_cnt == 0:
            self.q_next.load_state_dict(self.q_eval.state_dict())
            self.learn_step_cnt = 0

    def decrement_epsilon(self) -> None:
        """
            Decreases epsilon by eps_decay, never letting it drop below eps_min.

            Returns:
                None
        """

        self.epsilon = max(self.epsilon - self.eps_decay, self.eps_min)

    def save_models(self) -> None:
        """
            Saves both networks' checkpoints.

            Returns:
                None
        """

        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def save_best(self, final_state: tuple) -> None:
        """
            Saves both networks' checkpoints tagged with the given final state.

            Parameters:
                - final_state (tuple): The final state of the environment.

            Returns:
                None
        """

        self.q_eval.save_best(final_state)
        self.q_next.save_best(final_state)

    def load_models(self) -> None:
        """
            Loads both networks' checkpoints.

            Returns:
                None
        """

        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()

    def learn(self) -> float:
        """
            Samples a batch from the replay buffer and performs one Double-DQN gradient step on `q_eval`.

            Returns:
                float: The batch loss, or NaN while the buffer holds fewer than `batch_size` transitions.
        """

        # Wait until the buffer holds at least one full batch of transitions
        if self.memory.mem_cntr < self.batch_size:
            return np.nan

        self.q_eval.optimizer.zero_grad()

        self.replace_target_network()

        state, action, reward, next_state, done = self.memory.sample_buffer(self.batch_size)

        device = self.q_eval.device
        states = T.tensor(state).to(device)
        actions = T.tensor(action).to(device)
        rewards = T.tensor(reward).to(device)
        states_ = T.tensor(next_state).to(device)
        dones = T.tensor(done).to(device)

        # Row index on the same device as the tensors it indexes
        indices = T.arange(self.batch_size, device=device)

        # Q(s, a) for the actions actually taken; dueling aggregation Q = V + (A - mean(A))
        V_s, A_s = self.q_eval.forward(states)
        q_pred = T.add(V_s, (A_s - A_s.mean(dim=1, keepdim=True)))[indices, actions]

        # The target is treated as a constant, so build it without tracking gradients
        with T.no_grad():
            V_s_eval, A_s_eval = self.q_eval.forward(states_)
            V_s_, A_s_ = self.q_next.forward(states_)

            q_eval = T.add(V_s_eval, (A_s_eval - A_s_eval.mean(dim=1, keepdim=True)))
            q_next = T.add(V_s_, (A_s_ - A_s_.mean(dim=1, keepdim=True)))

            # Double DQN: the online network picks the next action, the target network scores it
            max_actions = T.argmax(q_eval, dim=1)

            # Terminal next states must not bootstrap any future value
            q_next[dones] = 0.0

            q_target = rewards + self.gamma * q_next[indices, max_actions]

        loss = self.q_eval.loss(q_pred, q_target)
        loss.backward()
        self.q_eval.optimizer.step()
        self.learn_step_cnt += 1

        self.decrement_epsilon()

        return loss.item()

### Helper Functions

In [7]:
def increment_heatmap(heatmap: np.array, episode_path: list[tuple], excluded_position: tuple = (0, 0)) -> None:
    """
        Increments the heatmap of visited cells based on the given episode path.

        Args:
            heatmap (np.array): 2-D visitation counts, modified in place.
            episode_path (list[tuple]): The (row, col) positions visited during the episode.
            excluded_position (tuple, optional): A cell whose visits are not counted. Defaults to (0, 0),
                the cell skipped by the original implementation. Pass None to count every cell.

        Returns:
            None, the heatmap is modified in place.
    """

    positions_count = Counter(episode_path)

    if excluded_position is not None:
        # Counter.__delitem__ ignores missing keys, so this is safe when the cell was never visited
        del positions_count[excluded_position]

    for (row, col), count in positions_count.items():
        heatmap[row][col] += count

def plot_learning(
    scores: list[float], 
    epsilons: list[float], 
    losses: list[float], 
    actions_history: list[dict], 
    heatmap: np.array,
    converged_episodes: dict[tuple, int], 
    env_img: np.array,
    filename: str = None
) -> None:
    """
        Plots the learning progression and visualizes the environment state.

        Two figures are produced: a row of four training histories, and a visitation heatmap next to
        the rendered environment image.

        Args:
            scores (list[float]): List of scores for each episode.
            epsilons (list[float]): List of epsilon values for each episode.
            losses (list[float]): List of mean episode losses for each episode.
            actions_history (list[dict]): Per-episode counts of 'Rand' and 'NN' actions.
            heatmap (np.array): 2D array representing the position visitation heatmap.
            converged_episodes (dict[tuple, int]): Maps final states to the episode at which they converged.
            env_img (np.array): Image of the environment state.
            filename (str, optional): Base output filename. When given, '_histories' and '_heatmap' are
                inserted before its extension (defaulting to '.png'). Defaults to None.

        Returns:
            None
    """

    if filename:
        # Insert the suffixes before the extension so the two figures never overwrite the same file
        file_root, file_ext = os.path.splitext(filename)
        file_ext = file_ext or '.png'

    fig, axes = plt.subplots(ncols=4, figsize=(18, 4))

    axes[0].plot(scores, color='C0')

    if converged_episodes:
        # Fixed qualitative palette: reproducible colours, and the global NumPy RNG (also used by the
        # agent) is left untouched
        palette = plt.cm.Dark2
        for idx, (final_state, episode) in enumerate(converged_episodes.items()):
            axes[0].axvline(
                episode, alpha=0.5,
                ls='--', c=palette(idx % palette.N),
                label=f'Converged to {final_state}'
            )

        axes[0].legend(loc='lower right')

    axes[0].set(
        title='Score progression',
        xlabel='Episode',
        ylabel='Score'
    )

    axes[1].plot(epsilons, color='C1')
    axes[1].set(
        title=r'$\epsilon$ progression',
        xlabel='Episode',
        ylabel=r'$\epsilon$'
    )

    axes[2].plot(losses, color='C2')
    axes[2].set(
        title='Loss progression',
        xlabel='Episode',
        ylabel='Mean Episode Loss'
    )

    # Episodes missing one action type get 0 for it
    df = pd.DataFrame(actions_history, columns=['Rand', 'NN']).fillna(0)

    axes[3].plot(df.Rand, c='C3', label='Random')
    axes[3].plot(df.NN, c='C4', label='Neural\nNetwork')

    axes[3].legend()
    axes[3].set(
        title='Type of actions performed',
        xlabel='Episode',
        ylabel='Quantity',
        yscale='log'
    )

    plt.tight_layout()

    if filename:
        plt.savefig(f'{file_root}_histories{file_ext}')

    plt.show()
    plt.close(fig)

    fig, axes = plt.subplots(ncols=2, figsize=(11, 5))

    sns.heatmap(
        heatmap, 
        annot=True, 
        annot_kws={'fontsize': 5},
        square=True,
        fmt='.2g', 
        linewidth=.5, 
        cmap=sns.color_palette("blend:#01153E,#FFA266", as_cmap=True),
        ax=axes[0]
    )

    axes[0].xaxis.tick_top()
    axes[0].set(
        title='Position visitation',
        xlabel='$x$',
        ylabel='$y$'
    )

    axes[1].imshow(env_img)
    axes[1].set_title("Episode's Final Env State")
    axes[1].xaxis.set_visible(False)
    axes[1].yaxis.set_visible(False)

    plt.tight_layout()

    if filename:
        plt.savefig(f'{file_root}_heatmap{file_ext}')

    plt.show()
    plt.close(fig)

In [8]:
def check_for_unblock(env: gym.Env, action: int, current_path: list, converged_paths: dict[tuple, list], converged_episodes: dict[tuple, int]) -> tuple[dict, dict]:
    """
        Re-opens a blocked (converged) treasure when the agent is about to reach it by a shorter path.

        If the cell that `action` leads to has a recorded converged path that is longer than the current
        path plus this step, the treasure value is written back into `env.sea_map` and the cell is dropped
        from both bookkeeping dicts. The input dicts are never mutated; new dicts are returned instead.

        Args:
            env (gym.Env): The environment; its `current_state`, `dir`, `sea_map` and `treasures` are used.
            action (int): The action about to be taken.
            current_path (list): Positions visited so far in the current episode.
            converged_paths (dict[tuple, list]): Converged path per final state.
            converged_episodes (dict[tuple, int]): Convergence episode per final state.

        Returns:
            tuple[dict, dict]: The (possibly filtered) converged paths and converged episodes.
    """

    target = tuple(env.current_state + env.dir[action])
    recorded_path = converged_paths.get(target, False)

    # +1 accounts for the step about to be taken
    shorter_route = bool(recorded_path) and len(recorded_path) > len(current_path) + 1
    if not shorter_route:
        return converged_paths, converged_episodes

    env.sea_map[target[0], target[1]] = env.treasures[target]

    def _drop_target(mapping: dict) -> dict:
        return {state: value for state, value in mapping.items() if state != target}

    return _drop_target(converged_paths), _drop_target(converged_episodes)

### Learning Loop

In [None]:
def learn_env(
    env: gym.Env, agent: Agent, 
    conversion_threshold: int = 300,
    n_trial: int = 1,
    load_checkpoint: bool = False, write_results: bool = False, 
    plots_dir: str = 'plots'
) -> tuple:
    """
        Learn the environment using the Deep Q-Managed algorithm.

        The agent is trained episode by episode. An episode path becomes converged once the identical path
        (compared via hash(str(path))) has been repeated `conversion_threshold` times. The path's last
        cell is then treated as a converged treasure:
            - it is blocked in `env.sea_map` (set to -10);
            - its path and episode are recorded;
            - the agent's exploration is raised again (epsilon 0.3, eps_decay 1e-3) so it looks for other treasures.
        Training stops when every cell of `env.sea_map` is <= 0, or after 100000 episodes.

        Args:
            env (gym.Env): The environment to learn. The attributes `treasures`, `sea_map`, `current_state`,
                `best_paths_lengths`, `name` and `render()` are used here, plus those read by `check_for_unblock`.
            agent (Agent): The agent that learns and interacts with the environment.
            conversion_threshold (int, optional): Number of repetitions of an identical episode path after which
                its final state is considered converged. Defaults to 300.
            n_trial (int, optional): Trial number; only used in the solution file name. Defaults to 1.
            load_checkpoint (bool, optional): Whether to load the agent's saved models before training. Defaults to False.
            write_results (bool, optional): Whether to save checkpoints, .npz histories, plots and the per-episode
                solution file. NOTE(review): the directories `<env name>_numpys`, `<env name>_solutions` and
                `plots_dir` are not created here — presumably they must already exist. Defaults to False.
            plots_dir (str, optional): The directory to save the plots. Defaults to 'plots'.

        Returns:
            tuple: (scores, eps_history, loss_history, actions_history, heatmap, converged_episodes).
    """
    
    # Hard cap on the number of training episodes
    n_episodes = 100000

    if load_checkpoint:
        agent.load_models()

    scores, eps_history = [], []
    # episode_losses is re-initialised at the start of every episode below
    loss_history, episode_losses = [], []
    actions_history = []

    # Saves the paths taken in the episodes to check for converged states
    # (maps hash(str(episode_path)) -> number of episodes that followed exactly that path)
    paths_hashs = Counter()

    # Saves the paths of conversion to a final state: { (i_final, j_final): path }
    # An empty list means that treasure has not converged (yet)
    converged_paths = { final_state: [] for final_state in env.treasures.keys() }
    
    # Saves the episode of conversion to a final state: { (i_final, j_final): episode }
    converged_episodes = {}

    # Matrix to save the position visitation
    heatmap = np.zeros(env.sea_map.shape)

    for episode in range(n_episodes):

        # Checking if the agent converged for treasures states
        # (converged treasures are overwritten with -10 below, so this holds once no positive cell remains)
        if np.all(env.sea_map <= 0):
            print('Agent converged for all treasure states')
            break
        
        observation, _ = env.reset()
        done = False
        score = 0

        # NOTE(review): the leading 0 keeps np.nanmean defined when every learn() call returned NaN
        # (buffer still filling), but it also pulls the mean episode loss toward 0 — confirm intended
        episode_losses = [0]
        actions_type = []
        
        episode_path = []
        episode_hash = None

        while not done: 
            action, action_type = agent.choose_action(observation)
            actions_type.append(action_type)
            
            # May restore a previously blocked treasure if this step reaches it by a shorter path
            converged_paths, converged_episodes = check_for_unblock(
                env, int(action), 
                episode_path, 
                converged_paths, converged_episodes
            )

            # The 4th return value (truncated) is ignored; episode length is capped at 1000 actions below
            next_observation, reward, done, _, _ = env.step(action)

            episode_path.append(tuple(env.current_state))

            score += reward

            agent.store_transition(observation, action, reward, next_observation, done)
            loss = agent.learn()
            
            episode_losses.append(loss)

            observation = next_observation

            # Episodes are force-ended after 1000 actions; the transition already stored for this step
            # keeps the environment's own `done` flag
            if len(actions_type) == 1000:
                done = True
        
        else:
            # while/else: the loop has no `break`, so this always runs once the episode ends
            increment_heatmap(heatmap, episode_path)

            episode_hash = hash(str(episode_path))
            paths_hashs[episode_hash] += 1

        scores.append(score)    
        eps_history.append(agent.epsilon)
        loss_history.append(np.nanmean(episode_losses))
        actions_history.append(dict(Counter(actions_type)))

        # paths_hashs is reset after every convergence and only the current episode's hash is incremented
        # each episode, so the top count can first reach the threshold only on the current episode's path
        _, hash_count = paths_hashs.most_common(1)[0]
        if hash_count >= conversion_threshold:
            
            # NOTE(review): this is the episode's last cell, which need not be a treasure if the repeated
            # path was cut off at 1000 actions — confirm that case cannot occur in practice
            converged_state = episode_path[-1]

            print(f'Converged to state {converged_state}')

            # Saving episode of conversion
            converged_episodes[converged_state] = episode

            # Saving the paths of conversion
            converged_paths[converged_state] = episode_path

            # Blocking converged state
            env.sea_map[converged_state[0], converged_state[1]] = -10

            # Increasing agent randomness
            agent.epsilon = 0.3   
            agent.eps_decay = 1e-3

            # Resetting paths taken in the episodes 
            paths_hashs = Counter()
            
            if write_results:
                agent.save_best(converged_state)  
                np.savez(
                    f"{env.name.lower()}_numpys/converged_{converged_state}.npz",
                    scores=scores,
                    eps_history=eps_history, 
                    loss_history=loss_history, 
                    actions_history=actions_history, 
                    heatmap=heatmap, 
                    converged_episodes=converged_episodes,
                )
    
                # Saves the plots for the converged state
                # (reset first so render() draws a freshly reset environment)
                env.reset()
                plot_learning(
                    scores, eps_history, loss_history, 
                    actions_history, heatmap, converged_episodes, 
                    env.render(), f'{plots_dir}/plots_converged_{converged_state}.png'
                )

        # Per-episode progress report (cleared by clear_output at the end of each episode)
        print(f'\nLatest episode length: {len(episode_path)}')
        print(f'Paths hashs: {len(paths_hashs)}\n{paths_hashs}\n')

        print(f'Best path lengths: {env.best_paths_lengths}')
        print(f'Converged lengths: { { final_state: len(path) for final_state, path in sorted(converged_paths.items()) } }')
        print(f'Converged paths: {converged_paths}')
        print(f'Converged episodes: {converged_episodes}')
        print(f'Episode {episode} of {n_episodes}\n\tScore: {score:.2f} AVG Score: {np.mean(scores[-50:]):.2f} Mean Loss: {loss_history[-1]:3f} Epsilon: {agent.epsilon:5f}')
        # format_map with a Counter: a missing action type formats as 0
        print('\tActions taken in episode: NN: {NN}, Rand: {Rand}'.format_map(Counter(actions_type)))
        print(f'\tFinal state: {tuple(observation)}')

        clear_output(wait=True)
        
        # plot_learning(scores, eps_history, loss_history, actions_history, heatmap, converged_episodes, env.render())

        if write_results:
            # Appends one line per episode to the solution file.
            # For each treasure (sorted by position) it writes '<-path length> <treasure value>' when
            # converged, and '0 0' otherwise. Treasures removed from converged_paths by check_for_unblock
            # are omitted from the line.
            env_name = env.name.lower()
            with open(f'{env_name}_solutions/solution_{env_name}_{conversion_threshold}_{n_trial}.txt', 'a') as solution_file:
                discovered_front = ' '.join([ 
                    f'{-1 * len(path)} {int(env.treasures[final_state])}' if len(path) else '0 0' 
                    for final_state, path in sorted(converged_paths.items()) 
                ])

                solution_file.write(f'{discovered_front}\n')

    return scores, eps_history, loss_history, actions_history, heatmap, converged_episodes

### Deep Sea Treasure (DST)

In [None]:
# Deep Sea Treasure, scalarised with equal weights on both reward objectives
dst_env = mo_gym.LinearReward(
    mo_gym.make('deep-sea-treasure-v1', render_mode='rgb_array'),
    weight=np.array([0.5, 0.5]),
)

agent = Agent(
    n_actions=4,
    input_dims=[2],
    gamma=0.99,
    learning_rate=1e-4,
    epsilon=1.0,
    eps_decay=3e-3,
    mem_size=10000,
    batch_size=10,
    replace=500,
    chkpt_dir='dst_bk',
)

In [12]:
start_time = time()

# Keyword arguments make the positional booleans explicit
(
    scores, eps_history, loss_history,
    actions_history, heatmap, converged_episodes,
) = learn_env(
    dst_env, agent,
    conversion_threshold=100,
    n_trial=1,
    load_checkpoint=False,
    write_results=False,
)

end_time = time()
print(f'Elapsed time: {(end_time - start_time)/60:.1f} min')

In [None]:
dst_env.reset()

# DST training histories next to a render of the reset environment
plot_learning(
    scores=scores,
    epsilons=eps_history,
    losses=loss_history,
    actions_history=actions_history,
    heatmap=heatmap,
    converged_episodes=converged_episodes,
    env_img=dst_env.render(),
)

### Bountiful Sea Treasure (BST)

In [None]:
# Bountiful Sea Treasure, scalarised with equal weights on both reward objectives
bst_env = mo_gym.LinearReward(
    mo_gym.make('bountiful-sea-treasure-v1', render_mode='rgb_array'),
    weight=np.array([0.5, 0.5]),
)

agent = Agent(
    n_actions=4,
    input_dims=[2],
    gamma=0.99,
    learning_rate=1e-4,
    epsilon=1.0,
    eps_decay=3e-3,
    mem_size=10000,
    batch_size=10,
    replace=500,
    chkpt_dir='bst_bk',
)

In [None]:
start_time = time()

(
    bst_scores, bst_eps_history, bst_loss_history,
    bst_actions_history, bst_heatmap, bst_converged_episodes,
) = learn_env(
    bst_env, agent,
    load_checkpoint=False,
    plots_dir='bst_plots',
)

end_time = time()
print(f'Elapsed time: {(end_time - start_time)/60:.1f} min')

In [None]:
bst_env.reset()

# Every history passed here must come from the BST run (bst_*), including the loss history
plot_learning(
    bst_scores, 
    bst_eps_history, bst_loss_history, 
    bst_actions_history, 
    bst_heatmap,
    bst_converged_episodes, 
    bst_env.render()
)

### Modified Bountiful Sea Treasure (MBST)

In [None]:
# Modified Bountiful Sea Treasure, scalarised with equal weights on both reward objectives
mbst_env = mo_gym.LinearReward(
    mo_gym.make('modified-bountiful-sea-treasure-v1', render_mode='rgb_array'),
    weight=np.array([0.5, 0.5]),
)

agent = Agent(
    n_actions=4,
    input_dims=[2],
    gamma=0.99,
    learning_rate=1e-4,
    epsilon=1.0,
    eps_decay=3e-3,
    mem_size=10000,
    batch_size=10,
    replace=500,
    chkpt_dir='mbst_bk',
)

In [None]:
start_time = time()

# Options must be passed by keyword: positionally, False and 'mbst_plots' would bind to
# conversion_threshold (0 -> every episode "converges") and n_trial
mbst_scores, mbst_eps_history, mbst_loss_history, mbst_actions_history, mbst_heatmap, mbst_converged_episodes = learn_env(
    mbst_env, agent, load_checkpoint=False, plots_dir='mbst_plots'
)

end_time = time()

print(f'Elapsed time: {(end_time - start_time)/60:.1f} min')

In [None]:
mbst_env.reset()

# Every history passed here must come from the MBST run (mbst_*), including the loss history
plot_learning(
    mbst_scores, 
    mbst_eps_history, mbst_loss_history, 
    mbst_actions_history, 
    mbst_heatmap,
    mbst_converged_episodes, 
    mbst_env.render()
)