Model design: if we have a game board of length 20, and 3 possible actions, then we have $20 \times 20 \times 3 = 1200$ possible $\left(x_\mathrm{terminal}, x, a\right)$ configurations. Question is how we want to batch these up? 

The impact of not providing fully i.i.d. inputs is not entirely clear to me. In principle this can bias the gradient updates and lead to e.g. stable orbits or other unwanted learning trajectories which do not converge on a local value optimum. However, we have some lee-way here. We want our gradient update to be an unbiased estimator for $\nabla_\omega\mathbb{E}_{x_\mathrm{terminal}, x, a}\left[\mathcal{L}\left(x_\mathrm{terminal},x,a\right)\right]$, which is equivalent to $\Sigma_{x_\mathrm{terminal}, x, a}p\left(x_\mathrm{terminal},x,a\right)\nabla_\omega\mathcal{L}\left(x_\mathrm{terminal},x,a\right)$ because our state-action probability does not depend upon $\mathcal{L}$. This expectation does not need to be estimated by sampling all combinations of $\left(x_\mathrm{terminal},x,a\right)$. If we know the probability of each state-action pair, we can calculate this gradient exactly using this summation. Or we can choose to sample some variables and sum over others, assuming that we can calculate the marginal probabilities. Fortunately we know these to be uniform according our environment (determines the distribution of $x_\mathrm{terminal}$) and according to our exploration policy (determines the distribution of $x,a$).

Our q-model will have the form $x_\mathrm{terminal}, x \rightarrow q\left(a\right)$ with a dataset size of $400$ and an ouput size of $3$. Let us assume that we will not be updating the bootstraps during an epoch. From the perspective of evaluating the bootstrap model, we will process all $1200$ datapoints at the start of the epoch to maximise parallelisation (this can be done using $400$ parallel forward passes). From the perspective of the bootstrap estimates, we therefore do not care how datapoints are batched during an epoch.

For an $n$-step return, our loss function takes the form
\begin{split}
  \mathcal{L}\left(x,a\right) ~&=~ \left| q_\mathrm{target}\left(x,a\right) - q_\mathrm{model}\left(x,a\right)  \right|^2 \\
  q_\mathrm{target}\left(x_0,a_0\right) ~&=~ r_{x_0,a_0} ~+~ \gamma \cdot r_{x_1,\mathrm{argmax}_{a_1}q_\mathrm{model}\left(x_1, a_1\right)} ~+~ \dots ~+~ \gamma^n \cdot q_\mathrm{bootstrap}\left(x_n, \mathrm{argmax}_{a_n} q_\mathrm{model}\left(x_n, a_n\right)\right) \\
\end{split}

Since a single evaluation of $q_\mathrm{model}\left(x,a\right)$ provides $q$ for all values of $a$, it makes sense to update all $a$ for a given $\left(x_\mathrm{terminal},x\right).

We must also determine whether to sample over $x_\mathrm{terminal}$ and/or $x$. It would be desirable to sample+batch in some way, to avoid updating all $1200$ permutations at once. In fact this number may increase  if we change the number of actions or size of the game board.

The complication is that, for every batch update, we want to calculate $q_\mathrm{target}$ using the latest model of $q_\mathrm{model}$ to select the actions for every step into the future (empirical or bootstrap). The only way to avoid doing datapoint-by-datapoint evaluation is to choose a single value of $x_\mathrm{terminal}$, evaluate $q_\mathrm{model}$ for all $x$, then choose a batch of $x$ over which to evaluate $x_\mathrm{target}$. In fact, I think this is only i.i.d if we update _all_ $x$ for that value of $x_\mathrm{terminal}$. This leads to batch sizes of $N_s \times N_a$, which equals $60$ for our base config. This is already large, and gets larger if we increase $N_s$. 

I think it is therefore more correct to simply create batches of fixed size $N_\mathrm{batch}$ which sample over both $x_\mathrm{terminal}$ and $x$. The number of updates will then be $N_\mathrm{batch}\times N_a$, as we still sum over all actions. For a given $x_\mathrm{terminal}$, I think we should still evaluate $q_\mathrm{model}$ for all $x$, which then allows us to calculate $q_\mathrm{target}$ for that datapoint without any further passes through the network. Per batch, we therefore implement $N_\mathrm{batch}$ forward passes of length $N_s$.


In [1]:
#  Required imports

import math, os, pickle, sys, time

import numpy as np

from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers    import Conv2D, Concatenate, Dense, Dropout, Flatten, Input, MaxPooling2D, Rescaling
from tensorflow.keras.models    import Model
from tensorflow.keras.callbacks import EarlyStopping

print("TensorFlow has found devices:")
for device in tf.config.list_physical_devices() :
    print(f"-  {device}")
    
  # create global list of all threads we will create
all_threads = []




TensorFlow has found devices:
-  PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
-  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [2]:
###
###  Utility functions
###

def generate_directory_for_file_path(fname, print_msg_on_dir_creation=True) :
    """
    Create the directory structure needed to place file fname. Call this before fig.savefig(fname, ...) to 
    make sure fname can be created without a FileNotFoundError
    Input:
       - fname: str
                name of file you want to create a tree of directories to enclose
                also create directory at this path if fname ends in '/'
       - print_msg_on_dir_creation: bool, default = True
                                    if True then print a message whenever a new directory is created
    """
    while "//" in fname :
        fname = fname.replace("//", "/")
    dir_tree = fname.split("/")
    dir_tree = ["/".join(dir_tree[:i]) for i in range(1,len(dir_tree))]
    dir_path = ""
    for dir_path in dir_tree :
        if len(dir_path) == 0 : continue
        if not os.path.exists(dir_path) :
            os.mkdir(dir_path)
            if print_msg_on_dir_creation :
                print(f"Directory {dir_path} created")
            continue
        if os.path.isdir(dir_path) : 
            continue
        raise RuntimeError(f"Cannot create directory {dir_path} because it already exists and is not a directory")
    



In [3]:
###
###  Class to allow configuration and interation with the environment
###

class Environment :
    
    def __init__(self, num_states=10, actions_dx = [-1, 0, 1], r_per_turn=-1., r_per_dx=1., r_per_b=-1., gamma=.99) :
        self.num_states = None
        self.actions_dx = None
        self.r_per_turn = None
        self.r_per_dx   = None
        self.r_per_b    = None
        self.gamma      = None
        self.configure_states_and_actions(num_states, actions_dx)
        self.configure_returns(r_per_turn, r_per_dx, r_per_b, gamma)
        
    def __str__(self) :
        return self.get_summary()
        
    def configure_states_and_actions(self, num_states=None, actions_dx=None) :
        if num_states is None : num_states = self.num_states
        if actions_dx is None : actions_dx = self.actions_dx
        actions_dx = np.array(actions_dx).flatten()
        if len(actions_dx) != len(set(actions_dx)) :
            raise RuntimeError(f"actions_dx contains duplicate entries ({actions_dx})")
        self.num_states  = num_states
        self.x_min       = 0
        self.x_max       = num_states - 1
        self.x_range     = self.x_max - self.x_min
        self.actions_dx  = actions_dx
        self.num_actions = len(actions_dx)
        self.a_min       = actions_dx.min()
        self.a_max       = actions_dx.max()
        self.a_range     = self.a_max - self.a_min
        
    def configure_returns(self, r_per_turn=None, r_per_dx=None, r_per_b=None, gamma=None) :
        if r_per_turn is None : r_per_turn = self.r_per_turn
        if r_per_dx   is None : r_per_dx   = self.r_per_dx
        if r_per_b    is None : r_per_b    = self.r_per_b
        if gamma      is None : gamma      = self.gamma
        self.r_per_turn = r_per_turn
        self.r_per_dx   = r_per_dx
        self.r_per_b    = r_per_b
        self.gamma      = gamma
    
    def enforce_action_dx_is_valid(self, a_dx) :
        if self.is_action_dx_out_of_bounds(a_dx) :
            raise RuntimeError(f"action dx ({a_dx}) not in allowed list ({self.actions_dx})")
    
    def enforce_action_idx_is_valid(self, a_idx) :
        if self.is_action_idx_out_of_bounds(a_idx) :
            raise RuntimeError(f"action index ({a_idx}) is out of bounds (0-{self.num_actions-1})")
        
    def enforce_state_is_valid(self, x, x_terminal=None) :
        if not (x_terminal is None) :
            if x == x_terminal :
                raise RuntimeError(f"cannot perform action because initial position ({x}) is terminal")
        if self.is_x_out_of_bounds(x) :
            raise RuntimeError(f"cannot perform action because initial position ({x}) is out of bounds")
        return True
    
    def get_action_dx_from_index(self, a_idx) :
        return self.actions_dx[a_idx]
    
    def get_action_index_from_dx(self, a_dx) :
        return list(self.actions_dx).index(a_dx)
    
    def get_q_true(self, x_terminal) :
        q_values = np.zeros(shape=(self.num_states, self.num_actions))
        for x in range(self.num_states) :
            for a_idx in range(self.num_actions) :
                if x == x_terminal :
                    q_values[x, a_idx] = np.nan
                    continue
                g, xp  = self.perform_action(x_terminal, x, a_idx)
                step_y = self.gamma
                while not xp == x_terminal :
                    dxp     = np.fabs(x + self.actions_dx - x_terminal)
                    ap_idx  = np.argmin(dxp)
                    r, xp   = self.perform_action(x_terminal, xp, ap_idx)
                    g      += step_y * r
                    step_y *= self.gamma
                q_values[x, a_idx] = g
        return q_values
    
    def get_all_valid_x_pairs(self) :
        all_pairs = []
        for x_terminal in range(self.num_states) :
            for x in range(self.num_states) :
                if x == x_terminal : continue
                all_pairs.append((x_terminal, x))
        return np.array(all_pairs)
    
    def get_summary(self, write_to=None) :
        str_summary  =  "Environment config:\n"
        str_summary += f"    num_states  | {self.num_states}\n"
        str_summary += f"    x_min       | {self.x_min}\n"
        str_summary += f"    x_max       | {self.x_max}\n"
        str_summary += f"    num_actions | {self.num_actions}\n"
        str_summary += f"    actions_dx  | {self.actions_dx}\n"
        str_summary += f"    r_per_turn  | {self.r_per_turn}\n"
        str_summary += f"    r_per_dx    | {self.r_per_dx}\n"
        str_summary += f"    r_per_b     | {self.r_per_b}\n"
        str_summary += f"    gamma       | {self.gamma}\n"
        str_summary += f"    Ns*Ns*Na    | {self.num_states*self.num_states*self.num_actions}\n"
        if not (write_to is None) : 
            write_to.write(str_summary)
        return str_summary
        
    def is_action_idx_out_of_bounds(self, a_idx) :
        a_idx = np.array(a_idx)
        return np.logical_or(a_idx < 0, a_idx >= self.num_actions)
        
    def is_action_dx_out_of_bounds(self, a_dx) :
        if type(a_dx) in [list, set, np.ndarray] :
            return np.array([self.is_action_dx_out_of_bounds(a_dxp) for a_dxp in a_dx])
        return a_dx not in self.actions_dx
        
    def is_x_out_of_bounds(self, x) :
        if x < self.x_min : return True
        if x > self.x_max : return True
        return False
         
    def perform_action(self, x_terminal, x, a_idx) :
        self.enforce_state_is_valid(x, x_terminal)
        self.enforce_action_idx_is_valid(a_idx)
        #  Iterate agent position, if hit boundary then add penalty and return to original position 
        x_p      = x + self.actions_dx[a_idx]
        reward_b = 0
        if self.is_x_out_of_bounds(x_p) :
            reward_b = self.r_per_b
            x_p      = x
        #  Get distance-based reward
        dx        = np.fabs(x   - x_terminal)
        dx_p      = np.fabs(x_p - x_terminal)
        reward_dx = self.r_per_dx * (dx - dx_p)
        #  Return total reward and new agent state
        reward_tot = self.r_per_turn + reward_b + reward_dx
        return reward_tot, x_p
    
    def print_summary(self, write_to=None) :
        if write_to is None : write_to = sys.stdout
        self.get_summary(write_to=write_to)
          

In [4]:
###
###  Methods for defining and manipulating keras q-value models
###


def create_q_model(env, name=None) :
    x_range, x_min, x_max = env.x_range, env.x_min, env.x_max
    input_layer    = Input((2,))
    next_layer     = Rescaling(2./x_range, offset=-(x_max+x_min)/x_range)(input_layer)
    next_layer     = Dense(100, activation="elu")(next_layer)
    next_layer     = Dense(400, activation="elu")(next_layer)
    next_layer     = Dense(100, activation="elu")(next_layer)
    output_layer   = Dense(env.num_actions, activation="linear")(next_layer)
    model          = Model(input_layer, output_layer, name=name)
    model.compile(loss="mse", optimizer="sgd")
    return model


def initialise_keras_objects(env, learning_rate, optimizer_type="sgd", tag="no_tag") :
    loss_fcn  = tf.keras.losses.MeanSquaredError()
    q_model   = create_q_model(env, name=f"q_model_{tag}")
    bs_model  = create_q_model(env, name=f"q_model_{tag}_clone")
    if optimizer_type.lower() == "sgd" :
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
    elif optimizer_type.lower() == "adam" :
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    else :
        raise NotImplementedError(f"optimizer_type = {optimizer_type} not recognised by method initialise_keras_objects")
    return loss_fcn, q_model, bs_model, optimizer


def get_mse(q_values_1, q_values_2) :
    q_res = q_values_2 - q_values_1
    q_res = np.where(np.isfinite(q_res), q_res, 0)
    q_res = q_res**2
    return np.mean(q_res)


In [17]:

class Experiment() :
    
    def __init__(self, run_config_dict=None, **kwargs) :
        if run_config_dict is None :
            run_config_dict = {}
        self.configure(run_config_dict={}, **kwargs)
        
    def configure(self, run_config_dict, **kwargs) :
        self.max_mse_true               = self.resolve_argument("max_mse_true"               , -np.inf , np.float32, run_config_dict, **kwargs)
        self.max_epochs                 = self.resolve_argument("max_epochs"                 , -1      , np.int32  , run_config_dict, **kwargs)
        self.batch_size                 = self.resolve_argument("batch_size"                 , 1       , np.int32  , run_config_dict, **kwargs)
        self.optimizer_type             = self.resolve_argument("optimizer_type"             , "sgd"   , str       , run_config_dict, **kwargs)
        self.learning_rate              = self.resolve_argument("learning_rate"              , 1e-3    , np.float32, run_config_dict, **kwargs)
        self.plot_estimate_after_epochs = self.resolve_argument("plot_estimate_after_epochs" , 5       , np.int32  , run_config_dict, **kwargs)
        self.plot_monitors_after_epochs = self.resolve_argument("plot_monitors_after_epochs" , 5       , np.int32  , run_config_dict, **kwargs)
        self.save_objects_after_epochs  = self.resolve_argument("save_objects_after_epochs"  , 1       , np.int32  , run_config_dict, **kwargs)
        self.act_from_true_greedy_policy= self.resolve_argument("act_from_true_greedy_policy", False   , bool      , run_config_dict, **kwargs)
        self.clone_after_epochs         = self.resolve_argument("clone_after_epochs"         , 5       , np.int32  , run_config_dict, **kwargs)
        self.priority_weight            = self.resolve_argument("priority_weight"            , 0.      , np.float32, run_config_dict, **kwargs)
        self.num_step_returns           = self.resolve_argument("num_step_returns"           , 1       , np.int32  , run_config_dict, **kwargs)
        self.run_tag                    = self.resolve_argument("run_tag"                    , "no_tag", str       , run_config_dict, **kwargs)
        self.run_idx                    = self.resolve_argument("run_idx"                    , 0       , np.int32  , run_config_dict, **kwargs)
        self.set_derived_constants()
        
    def export_run_config_dict(self) :
        run_config_dict = {}
        run_config_dict["max_mse_true"               ] = self.max_mse_true
        run_config_dict["max_epochs"                 ] = self.max_epochs
        run_config_dict["batch_size"                 ] = self.batch_size
        run_config_dict["optimizer_type"             ] = self.optimizer_type
        run_config_dict["learning_rate"              ] = self.learning_rate
        run_config_dict["plot_estimate_after_epochs" ] = self.plot_estimate_after_epochs
        run_config_dict["plot_monitors_after_epochs" ] = self.plot_monitors_after_epochs
        run_config_dict["save_objects_after_epochs"  ] = self.save_objects_after_epochs
        run_config_dict["act_from_true_greedy_policy"] = self.act_from_true_greedy_policy
        run_config_dict["clone_after_epochs"         ] = self.clone_after_epochs
        run_config_dict["priority_weight"            ] = self.priority_weight
        run_config_dict["num_step_returns"           ] = self.num_step_returns
        run_config_dict["run_tag"                    ] = self.run_tag
        run_config_dict["run_idx"                    ] = self.run_idx
        return run_config_dict
        
    def create_config(self, env, verbose=True) :
        #  Create config message
        config_str = ""
        config_str += f"="*114 + "\n"
        config_str += env.get_summary()
        config_str += f"="*114 + "\n"
        config_str += self.get_summary()
        config_str += f"="*114 + "\n"
        #  Print to stdout
        if verbose :
            print(config_str)
        #  Print to file
        config_fname = f"{self.top_directory}/config.txt"
        generate_directory_for_file_path(config_fname, print_msg_on_dir_creation=verbose)
        with open(config_fname, "w") as config_file :
            config_file.write(config_str)
            if hasattr(self, "q_model") :
                config_file.write("\nq-value model config:\n")
                self.q_model.summary(print_fn=lambda x: config_file.write(x + '\n'))
    
    def get_summary(self, write_to=None) :
        str_summary  =  "Experiment config:\n"
        str_summary += f"    max_mse_true                | {self.max_mse_true:.7}\n"
        str_summary += f"    max_epochs                  | {self.max_epochs}\n"
        str_summary += f"    batch_size                  | {self.batch_size}\n"
        str_summary += f"    optimizer_type              | {self.optimizer_type:.7}\n"
        str_summary += f"    learning_rate               | {self.learning_rate:.7}\n"
        str_summary += f"    plot_estimate_after_epochs  | {self.plot_estimate_after_epochs}\n"
        str_summary += f"    plot_monitors_after_epochs  | {self.plot_monitors_after_epochs}\n"
        str_summary += f"    save_objects_after_epochs   | {self.save_objects_after_epochs}\n"
        str_summary += f"    act_from_true_greedy_policy | {self.act_from_true_greedy_policy}\n"
        str_summary += f"    clone_after_epochs          | {self.clone_after_epochs}\n"
        str_summary += f"    priority_weight             | {self.priority_weight:.7}\n"
        str_summary += f"    num_step_returns            | {self.num_step_returns}\n"
        str_summary += f"    run_tag                     | {self.run_tag}\n"
        str_summary += f"    run_idx                     | {self.run_idx}\n"
        str_summary += f"    num_priority_datapoints     | {(self.num_priority_datapoints)}\n"
        if not (write_to is None) : 
            write_to.write(str_summary)
        return str_summary
        
    def record_monitors(self, epoch_idx, loss, max_abs_q, mse_true) :
        self.epochs_record   .append(epoch_idx)
        self.loss_record     .append(loss)
        self.max_abs_q_record.append(max_abs_q)
        self.mse_true_record .append(mse_true)
        
    def is_well_configured(self, debug_to=None, enforce_with_exception=False) :
        return_value, debug_str = True, ""
        if self.max_epochs <= 0 and self.max_mse_true == -np.inf :
            debug_str   += "WARNING: No max_mse_true or max_epochs set - experiment will run forever (i.e. until interrupted)\n"
        if self.batch_size <= 0 :
            return_value = False
            debug_str   += f"ERROR: batch_size must be positive and non-zero ({self.batch_size} found)\n"
        if self.optimizer_type.lower() not in ["sgd", "adam"] :
            optimizer_type = False
            debug_str   += f"ERROR: optimizer_type ({self.optimizer_type}) not recognised\n"
        if self.learning_rate <= 0 :
            return_value = False
            debug_str   += f"ERROR: learning_rate must be positive and non-zero ({self.learning_rate} found)\n"
        if self.priority_weight < 0 or self.priority_weight > 1 :
            return_value = False
            debug_str   += f"ERROR: priority_weight must be between 0 and 1 ({self.priority_weight} found)\n"
        if self.clone_after_epochs <= 0 :
            return_value = False
            debug_str   += f"ERROR: bootstrap_method is {self.bootstrap_method}, so clone_after_epochs must be positive and non-zero ({self.clone_after_epochs} found)\n"
        if self.num_step_returns < 1 :
            return_value = False
            debug_str   += f"ERROR: num_step_returns must be positive and non-zero ({self.num_step_returns} found)\n"
        if self.run_idx < 0 :
            return_value = False
            debug_str   += f"ERROR: run_idx must be positive ({self.run_idx} found)\n"
        if enforce_with_exception and not return_value :
            raise RuntimeError(f"Experiment config failed with the following problems:\n{debug_str}")
        if not (debug_to is None) :
            debug_to.write(debug_str)
        return return_value
    
    def plot_value_functions(self, bs_values, q_values, save=False, show=False, close=False, dpi=100) :

        plot_x     = np.arange(self.num_states)
        start_time = time.time()
        
        test_x_vals = [int(.15*(self.num_states-1)), int(.5*(self.num_states-1)), int(.85*(self.num_states-1))]

        fig = plt.figure(figsize=(20, 12))
        fig.set_facecolor("white")
        fig.set_alpha(1)
        
        minor_locator = MultipleLocator(.5)

        #  Plot value estimates
        ax1 = fig.add_subplot(3, 4, 1)
        ax1.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax1.plot(plot_x, q_values     [test_x_vals[0],:,0], "x-" , c="r"         , ms=5, lw=1, alpha=0.5, label="Estimated")
        ax1.plot(plot_x, bs_values    [test_x_vals[0],:,0], "x-" , c="b"         , ms=5, lw=1, alpha=0.5, label="Bootstrap")
        ax1.plot(plot_x, self.q_target[test_x_vals[0],:,0], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5, label="Target")
        ax1.plot(plot_x, self.q_true  [test_x_vals[0],:,0], ".--", c="gray"      , ms=5, lw=4, alpha=0.5, label="True")
        ax1.grid(which="major", color='darkgray', linestyle='-', linewidth=.5)
        ax1.grid(which="minor", color='gray'    , linestyle=':', linewidth=.5)
        ax1.yaxis.set_minor_locator(minor_locator)

        ax2 = fig.add_subplot(3, 4, 2)
        ax2.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax2.plot(plot_x, q_values     [test_x_vals[0],:,1], "x-" , c="r"         , ms=5, lw=1, alpha=0.5, label="Estimated")
        ax2.plot(plot_x, bs_values    [test_x_vals[0],:,1], "x-" , c="b"         , ms=5, lw=1, alpha=0.5, label="Bootstrap")
        ax2.plot(plot_x, self.q_target[test_x_vals[0],:,1], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5, label="Target")
        ax2.plot(plot_x, self.q_true  [test_x_vals[0],:,1], ".--", c="gray"      , ms=5, lw=4, alpha=0.5, label="True")
        ax2.grid(which="major", color='darkgray', linestyle='-', linewidth=.5)
        ax2.grid(which="minor", color='gray'    , linestyle=':', linewidth=.5)
        ax2.yaxis.set_minor_locator(minor_locator)

        ax3 = fig.add_subplot(3, 4, 3)
        ax3.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax3.plot(plot_x, q_values     [test_x_vals[0],:,2], "x-" , c="r"         , ms=5, lw=1, alpha=0.5, label="Estimated")
        ax3.plot(plot_x, bs_values    [test_x_vals[0],:,2], "x-" , c="b"         , ms=5, lw=1, alpha=0.5, label="Bootstrap")
        ax3.plot(plot_x, self.q_target[test_x_vals[0],:,2], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5, label="Target")
        ax3.plot(plot_x, self.q_true  [test_x_vals[0],:,2], ".--", c="gray"      , ms=5, lw=4, alpha=0.5, label="True")
        ax3.grid(which="major", color='darkgray', linestyle='-', linewidth=.5)
        ax3.grid(which="minor", color='gray'    , linestyle=':', linewidth=.5)
        ax3.yaxis.set_minor_locator(minor_locator)
        
        ax4 = fig.add_subplot(3, 4, 5)
        ax4.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax4.plot(plot_x, q_values     [test_x_vals[1],:,0], "x-" , c="r"         , ms=5, lw=1, alpha=0.5, label="Estimated")
        ax4.plot(plot_x, bs_values    [test_x_vals[1],:,0], "x-" , c="b"         , ms=5, lw=1, alpha=0.5, label="Bootstrap")
        ax4.plot(plot_x, self.q_target[test_x_vals[1],:,0], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5, label="Target")
        ax4.plot(plot_x, self.q_true  [test_x_vals[1],:,0], ".--", c="gray"      , ms=5, lw=4, alpha=0.5, label="True")
        ax4.grid(which="major", color='darkgray', linestyle='-', linewidth=.5)
        ax4.grid(which="minor", color='gray'    , linestyle=':', linewidth=.5)
        ax4.yaxis.set_minor_locator(minor_locator)

        ax5 = fig.add_subplot(3, 4, 6)
        ax5.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax5.plot(plot_x, q_values     [test_x_vals[1],:,1], "x-" , c="r"         , ms=5, lw=1, alpha=0.5, label="Estimated")
        ax5.plot(plot_x, bs_values    [test_x_vals[1],:,1], "x-" , c="b"         , ms=5, lw=1, alpha=0.5, label="Bootstrap")
        ax5.plot(plot_x, self.q_target[test_x_vals[1],:,1], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5, label="Target")
        ax5.plot(plot_x, self.q_true  [test_x_vals[1],:,1], ".--", c="gray"      , ms=5, lw=4, alpha=0.5, label="True")
        ax5.grid(which="major", color='darkgray', linestyle='-', linewidth=.5)
        ax5.grid(which="minor", color='gray'    , linestyle=':', linewidth=.5)
        ax5.yaxis.set_minor_locator(minor_locator)

        ax6 = fig.add_subplot(3, 4, 7)
        ax6.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax6.plot(plot_x, q_values     [test_x_vals[1],:,2], "x-" , c="r"         , ms=5, lw=1, alpha=0.5, label="Estimated")
        ax6.plot(plot_x, bs_values    [test_x_vals[1],:,2], "x-" , c="b"         , ms=5, lw=1, alpha=0.5, label="Bootstrap")
        ax6.plot(plot_x, self.q_target[test_x_vals[1],:,2], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5, label="Target")
        ax6.plot(plot_x, self.q_true  [test_x_vals[1],:,2], ".--", c="gray"      , ms=5, lw=4, alpha=0.5, label="True")
        ax6.grid(which="major", color='darkgray', linestyle='-', linewidth=.5)
        ax6.grid(which="minor", color='gray'    , linestyle=':', linewidth=.5)
        ax6.yaxis.set_minor_locator(minor_locator)

        ax7 = fig.add_subplot(3, 4, 9)
        ax7.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax7.plot(plot_x, q_values     [test_x_vals[2],:,0], "x-" , c="r"         , ms=5, lw=1, alpha=0.5, label="Estimated")
        ax7.plot(plot_x, bs_values    [test_x_vals[2],:,0], "x-" , c="b"         , ms=5, lw=1, alpha=0.5, label="Bootstrap")
        ax7.plot(plot_x, self.q_target[test_x_vals[2],:,0], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5, label="Target")
        ax7.plot(plot_x, self.q_true  [test_x_vals[2],:,0], ".--", c="gray"      , ms=5, lw=4, alpha=0.5, label="True")
        ax7.grid(which="major", color='darkgray', linestyle='-', linewidth=.5)
        ax7.grid(which="minor", color='gray'    , linestyle=':', linewidth=.5)
        ax7.yaxis.set_minor_locator(minor_locator)
        ax7.set_xlabel("$x$", labelpad=15, fontsize=14)

        ax8 = fig.add_subplot(3, 4, 10)
        ax8.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax8.plot(plot_x, q_values     [test_x_vals[2],:,1], "x-" , c="r"         , ms=5, lw=1, alpha=0.5, label="Estimated")
        ax8.plot(plot_x, bs_values    [test_x_vals[2],:,1], "x-" , c="b"         , ms=5, lw=1, alpha=0.5, label="Bootstrap")
        ax8.plot(plot_x, self.q_target[test_x_vals[2],:,1], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5, label="Target")
        ax8.plot(plot_x, self.q_true  [test_x_vals[2],:,1], ".--", c="gray"      , ms=5, lw=4, alpha=0.5, label="True")
        ax8.grid(which="major", color='darkgray', linestyle='-', linewidth=.5)
        ax8.grid(which="minor", color='gray'    , linestyle=':', linewidth=.5)
        ax8.yaxis.set_minor_locator(minor_locator)
        ax8.set_xlabel("$x$", labelpad=15, fontsize=14)

        ax9 = fig.add_subplot(3, 4, 11)
        ax9.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax9.plot(plot_x, q_values     [test_x_vals[2],:,2], "x-" , c="r"         , ms=5, lw=1, alpha=0.5, label="Estimated")
        ax9.plot(plot_x, bs_values    [test_x_vals[2],:,2], "x-" , c="b"         , ms=5, lw=1, alpha=0.5, label="Bootstrap")
        ax9.plot(plot_x, self.q_target[test_x_vals[2],:,2], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5, label="Target")
        ax9.plot(plot_x, self.q_true  [test_x_vals[2],:,2], ".--", c="gray"      , ms=5, lw=4, alpha=0.5, label="True")
        ax9.grid(which="major", color='darkgray', linestyle='-', linewidth=.5)
        ax9.grid(which="minor", color='gray'    , linestyle=':', linewidth=.5)
        ax9.yaxis.set_minor_locator(minor_locator)
        ax9.set_xlabel("$x$", labelpad=15, fontsize=14)
        
        #  Draw action choices
        ax10 = fig.add_subplot(3, 4, 4)
        ax10.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax10.yaxis.set_label_position("right")
        ax10.yaxis.tick_right()
        ax10.plot(plot_x, self.greedy_dx     [test_x_vals[0],:], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5)
        ax10.plot(plot_x, self.greedy_dx_true[test_x_vals[0],:], ".--", c="gray"      , ms=5, lw=4, alpha=0.5)
        ax10.axhline(0, lw=1, c="k", ls="-")
        ax10.grid(True, which='both')
        ax10.xaxis.set_ticklabels([])
        
        ax11 = fig.add_subplot(3, 4, 8)
        ax11.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax11.yaxis.set_label_position("right")
        ax11.yaxis.tick_right()
        ax11.plot(plot_x, self.greedy_dx     [test_x_vals[1],:], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5)
        ax11.plot(plot_x, self.greedy_dx_true[test_x_vals[1],:], ".--", c="gray"      , ms=5, lw=4, alpha=0.5)
        ax11.axhline(0, lw=1, c="k", ls="-")
        ax11.grid(True, which='both')
        ax11.xaxis.set_ticklabels([])
        
        ax12 = fig.add_subplot(3, 4, 12)
        ax12.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
        ax12.yaxis.set_label_position("right")
        ax12.yaxis.tick_right()
        ax12.plot(plot_x, self.greedy_dx     [test_x_vals[2],:], "o-" , c="darkorange", ms=5, lw=4, alpha=0.5)
        ax12.plot(plot_x, self.greedy_dx_true[test_x_vals[2],:], ".--", c="gray"      , ms=5, lw=4, alpha=0.5)
        ax12.axhline(0, lw=1, c="k", ls="-")
        ax12.grid(True, which='both')
        ax12.set_xlabel("$x$", labelpad=15, fontsize=14)

        #  Draw accompanying plot objects
        ax1.legend(loc=(0.9,1.38), ncol=4, fontsize=20, frameon=False)
        ax1.axhline(0, lw=1, c="k", ls="-")
        ax2.axhline(0, lw=1, c="k", ls="-")
        ax3.axhline(0, lw=1, c="k", ls="-")
        ax4.axhline(0, lw=1, c="k", ls="-")
        ax5.axhline(0, lw=1, c="k", ls="-")
        ax6.axhline(0, lw=1, c="k", ls="-")
        ax7.axhline(0, lw=1, c="k", ls="-")
        ax8.axhline(0, lw=1, c="k", ls="-")
        ax9.axhline(0, lw=1, c="k", ls="-")
        
        ax1.text(0.01, 1.01, "Value function (action = left)" , ha="left", va="bottom", weight="bold", transform=ax1.transAxes, 
                 alpha=0.8, fontsize=12, c="k")
        ax2.text(0.01, 1.01, "Value function (action = stay)" , ha="left", va="bottom", weight="bold", transform=ax2.transAxes, 
                 alpha=0.8, fontsize=12, c="k")
        ax3.text(0.01, 1.01, "Value function (action = right)", ha="left", va="bottom", weight="bold", transform=ax3.transAxes, 
                 alpha=0.8, fontsize=12, c="k")
        ax10.text(0.01, 1.01, "Greedy action choice", ha="left", va="bottom", weight="bold", transform=ax10.transAxes, 
                 alpha=0.8, fontsize=12, c="k")
        
        ax1.text(-0.18, .5, f"$x_t = {test_x_vals[0]}$:" , ha="right", va="top", transform=ax1.transAxes, fontsize=20, c="k")
        ax4.text(-0.18, .5, f"$x_t = {test_x_vals[1]}$:" , ha="right", va="top", transform=ax4.transAxes, fontsize=20, c="k")
        ax7.text(-0.18, .5, f"$x_t = {test_x_vals[2]}$:" , ha="right", va="top", transform=ax7.transAxes, fontsize=20, c="k")

        #  Figure out and set y-axis ranges
        y_min   = np.nanmin([0, np.nanmin(q_values), np.nanmin(bs_values), np.nanmin(self.q_target), np.nanmin(self.q_true)])
        y_max   = np.nanmax([0, np.nanmax(q_values), np.nanmax(bs_values), np.nanmax(self.q_target), np.nanmax(self.q_true)])
        y_range = y_max - y_min
        y_pad   = 0.1
        y_lim   = [y_min - y_pad*y_range, y_max + y_pad*y_range]
        ax1.set_ylim(y_lim)
        ax2.set_ylim(y_lim)
        ax3.set_ylim(y_lim)
        ax4.set_ylim(y_lim)
        ax5.set_ylim(y_lim)
        ax6.set_ylim(y_lim)
        ax7.set_ylim(y_lim)
        ax8.set_ylim(y_lim)
        ax9.set_ylim(y_lim)
        
        #  Format axes
        fig.subplots_adjust(hspace=0.05, wspace=0.05)
        ax1.xaxis.set_ticklabels([])
        ax2.xaxis.set_ticklabels([])
        ax3.xaxis.set_ticklabels([])
        ax4.xaxis.set_ticklabels([])
        ax5.xaxis.set_ticklabels([])
        ax6.xaxis.set_ticklabels([])
        ax2.yaxis.set_ticklabels([])
        ax3.yaxis.set_ticklabels([])
        ax5.yaxis.set_ticklabels([])
        ax6.yaxis.set_ticklabels([])
        ax8.yaxis.set_ticklabels([])
        ax9.yaxis.set_ticklabels([])

        #  Draw text boxes displaying title and num. epochs
        ax1.text(0., 1.42, f"After {self.epoch_idx} epochs", ha="left", va="bottom", weight="bold", transform=ax1.transAxes, fontsize=20)
        ax1.text(0., 1.29, f"batch_size = {self.batch_size}, optimizer_type = {self.optimizer_type:.6}, learning_rate = {self.learning_rate:.6}, clone_after_epochs={self.clone_after_epochs}, act_from_true_greedy_policy={self.act_from_true_greedy_policy}",
                 ha="left", va="bottom", style="italic", transform=ax1.transAxes, fontsize=14)
        ax1.text(0., 1.18, f"priority_weight = {self.priority_weight:.2f}, num_priority_datapoints={self.num_priority_datapoints}, num_step_returns={self.num_step_returns}",
                 ha="left", va="bottom", style="italic", transform=ax1.transAxes, fontsize=14)
        #  Save / show / close
        if save :
            fname = f"{self.top_directory}/value_estimates_epoch{self.epoch_idx}.png"
            generate_directory_for_file_path(fname, print_msg_on_dir_creation=False)
            plt.savefig(fname, bbox_inches="tight", dpi=dpi)
        if show :
            plt.show(fig)
        if close :
            plt.close(fig)

        #  Return figure and axis
        return fig, ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9
    
    def plot_training_curves(self, save=False, show=False, close=False, dpi=100) :

        fig = plt.figure(figsize=(30,15))
        fig.set_facecolor("white")
        fig.set_alpha(1)

        ax1 = fig.add_subplot(3, 1, 1)
        ax1.grid(True, which='both')
        ax1.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
        ax1.set_title(r"Mean loss per batch", fontsize=30)
        ax1.xaxis.set_ticklabels([])
        ax1.plot(self.epochs_record, self.loss_record, "o-", c="r", lw=2, ms=4)
        ax1.set_yscale("log")

        ax2 = fig.add_subplot(3, 1, 2)
        ax2.grid(True, which='both')
        ax2.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
        ax2.set_title(r"MSE wrt true values", fontsize=30)
        ax2.xaxis.set_ticklabels([])
        ax2.plot(self.epochs_record, self.mse_true_record, "o-", c="r", lw=2, ms=4)
        ax2.set_yscale("log")

        ax3 = fig.add_subplot(3, 1, 3)
        ax3.grid(True, which='both')
        ax3.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
        ax3.set_title(r"Max $|q(s,a)|$ over all batches", fontsize=30)
        ax3.set_xlabel(r"Epoch", labelpad=15, fontsize=30)
        ax3.plot(self.epochs_record, self.max_abs_q_record, "o-", c="r", lw=2, ms=4)
        ax3.axhline(0, ls="--", lw=2, c="gray")
        if np.isfinite(self.max_abs_q_true) :
            ax3.axhline(self.max_abs_q_true, ls="--", lw=2, c="gray")
            ax3.text(0, self.max_abs_q_true, "True maximum", fontsize=20, ha="left", va="top", c="k")

        fig.subplots_adjust(hspace=0.2)

        if save :
            fname = f"{self.top_directory}/training_curve.pdf"
            generate_directory_for_file_path(fname, print_msg_on_dir_creation=False)
            plt.savefig(fname, bbox_inches="tight", dpi=dpi)
        if show :
            plt.show(fig)
        if close :
            plt.close(fig)

        return fig, ax1, ax2, ax3
    
    def print_summary(self, write_to=None) :
        if write_to is None : write_to = sys.stdout
        self.get_summary(write_to=write_to)
        
    def resolve_argument(self, arg_name, arg_default, dtype, config_dict, **kwargs) :
        if arg_name in kwargs :
            arg_val = kwargs.get(arg_name)
        else :
            arg_val = config_dict.get(arg_name, arg_default)
        return dtype(arg_val)
    
    def get_values_from_model(self, q_model) :
        values = np.zeros(shape=(self.num_states, self.num_states, self.num_actions))
        for x_terminal in range(self.num_states) :
            datapoints = np.array([[x_terminal, x]  for x in range(self.num_states)])
            values[x_terminal] = q_model.predict(datapoints)
            values[x_terminal, x_terminal] = np.ones(shape=(self.num_actions,))*np.nan
        return values
    
    def get_bs_values(self) :
        return self.get_values_from_model(self.bs_model)
    
    def get_q_values(self) :
        return self.get_values_from_model(self.q_model)
    
    def run(self, env, initialise=True, verbose=True) :
        
        #   Re-initialise training objects to start training from scratch - must be done on first run() call
        if initialise :
            self._initialise_run(env, verbose=verbose)
        
        #   At start of new training loop, set constants and plot value functions
        num_states, num_actions, start_time = env.num_states, env.num_actions, time.time()
        bs_values, q_values = self.get_bs_values(), self.get_q_values()
        self.plot_value_functions(bs_values, q_values, save=True, close=True)
        
        #   Start epochs loop
        mse_true = np.inf
        while (self.epoch_idx < self.max_epochs or self.max_epochs < 0) and mse_true > self.max_mse_true :

            #   Start epoch printout message
            if verbose :
                sys.stdout.write(f"Epoch {self.epoch_idx+1} / {self.max_epochs}  [t={time.time()-start_time:.2f}s]")

            #   Set bootstrap model
            if self.clone_after_epochs > 0 and self.epoch_idx % self.clone_after_epochs == 0 :
                self.bs_model.set_weights(self.q_model.get_weights()) 
                
            #   Evaluate bootstrap values
            bs_values = self.get_bs_values()
                
            #   Start batches loop
            epoch_losses, max_abs_q_values = [], []
            for batch_idx in range(math.ceil(self.num_train_X/self.batch_size)) :
                
                #   Resolve sample indices to be used for this batch update
                batch_idx_low     = batch_idx*self.batch_size
                batch_idx_high    = min((batch_idx+1)*self.batch_size, self.num_train_X)
                actual_batch_size = batch_idx_high - batch_idx_low
                if actual_batch_size == 0 : continue

                #   Update the current epoch message to keep track of batch progress 
                if verbose :
                    sys.stdout.write(f"\rEpoch {self.epoch_idx+1} / {self.max_epochs} batch indices ({batch_idx_low}, {batch_idx_high}) / {self.num_train_X}  [t={time.time()-start_time:.2f}s]")

                #   Get states to update for this batch
                batch_train_X    = self.train_X[batch_idx_low:batch_idx_high]
                batch_x_terminal = batch_train_X[:,0]
                batch_x          = batch_train_X[:,1]

                #   Evaluate q_target for this batch (and update global monitor value)
                batch_q_target = []
                for x_terminal, x in batch_train_X :
                    model_args   = np.array([[x_terminal, xp]  for xp in range(num_states)])
                    q_model_vals = self.q_model.predict([model_args])
                    q_a          = []
                    values_to_act_from = q_model_vals
                    if self.act_from_true_greedy_policy : 
                        values_to_act_from = self.q_true[x_terminal]
                    for a_idx in range(num_actions) :
                        g, x_p = env.perform_action(x_terminal, x, a_idx)
                        step_y = env.gamma
                        for step_idx in range(self.num_step_returns-1) :
                            if x_p != x_terminal :
                                a_idx_p = np.argmax(values_to_act_from[x_p])
                                r, x_p  = env.perform_action(x_terminal, x_p, a_idx_p)
                                g      += step_y * r
                            step_y *= env.gamma
                        if x_p != x_terminal :
                            a_idx_p  = np.argmax(values_to_act_from[x_p])
                            g       += step_y * bs_values[x_terminal, x_p, a_idx_p]
                        q_a.append(g)
                        self.q_target[x_terminal, x, a_idx] = g
                    batch_q_target.append(q_a)
                batch_q_target = tf.constant(batch_q_target)

                #   Apply gradient updates and store monitor values
                with tf.GradientTape() as tape:
                    batch_q_model = self.q_model(batch_train_X, training=True)
                    batch_loss    = self.loss_fcn(batch_q_target, batch_q_model)
                    if self.priority_weight > 0 and self.priority_weight <= 1 :
                        priority_q    = self.q_model(self.priority_X, training=True)
                        priority_q    = tf.gather_nd(priority_q, indices=self.priority_a_idcs)
                        priority_loss = self.loss_fcn(self.priority_values, priority_q)
                        train_loss    = self.priority_weight * priority_loss + (1 - self.priority_weight) * batch_loss
                    else :
                        train_loss = batch_loss
                    grads = tape.gradient(train_loss, self.q_model.trainable_weights)
                    self.optimizer.apply_gradients(zip(grads, self.q_model.trainable_weights))
                    epoch_losses.append(train_loss.numpy())
                    max_abs_q_values.append(np.fabs(batch_q_model.numpy()).max())
                    
            #   Calculate post-epoch MSE wrt true values, and other monitors
            q_values = self.get_q_values()
            for x_terminal in range(env.num_states) :
                for x in range(env.num_states) :
                    if x == x_terminal :
                        self.greedy_dx[x_terminal, x] = np.nan
                        continue
                    self.greedy_dx[x_terminal, x] = env.actions_dx[np.argmax(q_values[x_terminal,x])]
            mse_true = get_mse(q_values, self.q_true)
            epoch_mean_loss, epoch_max_abs_q = np.mean(epoch_losses), np.max(max_abs_q_values)
            
            #   Store monitor values
            self.epochs_record   .append(self.epoch_idx)
            self.loss_record     .append(epoch_mean_loss)
            self.mse_true_record .append(mse_true)
            self.max_abs_q_record.append(epoch_max_abs_q)
            
            #   End print line
            if verbose :
                sys.stdout.write(f"\rEpoch {self.epoch_idx+1} / {self.max_epochs}  [t={time.time()-start_time:.2f}s]  <loss = {epoch_mean_loss:.5f}, mse_true = {mse_true:.4f} / {self.max_mse_true:.4f}, max_abs_q = {epoch_max_abs_q:.1f}>".ljust(100)+"\n")
    
            #   Update number of completed epochs
            self.epoch_idx += 1
        
            #   Plot value function
            if self.plot_estimate_after_epochs > 0 and self.epoch_idx % self.plot_estimate_after_epochs == 0 :
                self.plot_value_functions(bs_values, q_values, save=True, close=True)

            #   Plot training curves
            if self.plot_monitors_after_epochs > 0 and self.epoch_idx % self.plot_monitors_after_epochs == 0 :
                self.plot_training_curves(save=True, close=True)

            #   Save objects
            if self.save_objects_after_epochs > 0 and self.epoch_idx % self.save_objects_after_epochs == 0 :
                self.save_run_progress()
        
        #   Make sure final plots and objects are saved
        self.plot_value_functions(bs_values, q_values, save=True, close=True)
        self.plot_training_curves(save=True, close=True)
        self.save_run_progress()
    
    def save_run_progress(self, verbose=True) :
        fname   = f"{self.top_directory}/saved_objects.pickle"
        to_save = self.export_run_config_dict()
        to_save["run:epoch_idx"]               = self.epoch_idx
        to_save["run:epochs_record"]           = self.epochs_record
        to_save["run:loss_record"]             = self.loss_record
        to_save["run:max_abs_q_record"]        = self.max_abs_q_record
        to_save["run:mse_true_record"]         = self.mse_true_record
        to_save["run:train_X" ]                = self.train_X
        to_save["run:num_train_X"]             = self.num_train_X
        to_save["run:num_states"]              = self.num_states
        to_save["run:num_actions"]             = self.num_actions
        to_save["run:q_true"]                  = self.q_true
        to_save["run:q_target"]                = self.q_target
        to_save["run:max_abs_q_true"]          = self.max_abs_q_true
        to_save["run:greedy_dx"]               = self.greedy_dx
        to_save["run:greedy_dx_true"]          = self.greedy_dx_true
        to_save["run:priority_X"]              = self.priority_X
        to_save["run:priority_a_idcs"]         = self.priority_a_idcs
        to_save["run:priority_values"]         = self.priority_values
        to_save["run:num_priority_datapoints"] = self.num_priority_datapoints
        to_save["general:max_mse_true"               ]  = self.max_mse_true
        to_save["general:max_epochs"                 ]  = self.max_epochs
        to_save["general:batch_size"                 ]  = self.batch_size
        to_save["general:optimizer_type"             ]  = self.optimizer_type
        to_save["general:learning_rate"              ]  = self.learning_rate
        to_save["general:plot_estimate_after_epochs" ]  = self.plot_estimate_after_epochs
        to_save["general:plot_monitors_after_epochs" ]  = self.plot_monitors_after_epochs
        to_save["general:save_objects_after_epochs"  ]  = self.save_objects_after_epochs
        to_save["general:act_from_true_greedy_policy"]  = self.act_from_true_greedy_policy
        to_save["general:clone_after_epochs"         ]  = self.clone_after_epochs
        to_save["general:priority_weight"            ]  = self.priority_weight
        to_save["general:num_step_returns"           ]  = self.num_step_returns
        to_save["general:run_tag"                    ]  = self.run_tag
        to_save["general:run_idx"                    ]  = self.run_idx
        generate_directory_for_file_path(fname, print_msg_on_dir_creation=verbose)
        pickle.dump(to_save, open(fname,"wb"))
        tf_log_level = tf.get_logger().level
        tf.get_logger().setLevel('WARNING')
        self.q_model .save(f"{self.top_directory}/q_model")
        self.bs_model.save(f"{self.top_directory}/bs_model")
        tf.get_logger().setLevel(tf_log_level)
        
    def set_derived_constants(self) :
        self.top_directory = f"figures/Helicopter_1D_random_finish/{self.run_tag}/experiment_{self.run_idx}"
    
    def _initialise_keras_objects(self, env) :
        self.loss_fcn, self.q_model, self.bs_model, self.optimizer = initialise_keras_objects(
                                                env, self.learning_rate, optimizer_type=self.optimizer_type, 
                                                tag=f"{self.run_tag}_run{self.run_idx}")
    
    def _initialise_monitor_records(self) :
        self.epochs_record    = []
        self.loss_record      = []
        self.max_abs_q_record = []
        self.mse_true_record  = []
        
    def _initialise_run(self, env, verbose=True) :
        self.epoch_idx   = 0
        self.num_states  = env.num_states
        self.num_actions = env.num_actions
        self._initialise_keras_objects(env)
        self._initialise_monitor_records()
        self._intialise_q_true(env)
        self._intialise_priority_datapoints(env)
        self.create_config(env, verbose=verbose)
        train_X          = env.get_all_valid_x_pairs()
        shuffle_idcs     = np.arange(len(train_X))
        np.random.shuffle(shuffle_idcs)
        self.train_X     = train_X[shuffle_idcs]
        self.num_train_X = len(self.train_X)
        self.q_target    = np.zeros_like(self.q_true)
        self.greedy_dx   = np.zeros(shape=(self.num_states, self.num_states))
        for x_terminal in range(self.num_states) :
            datapoints = np.array([[x_terminal, x]  for x in range(self.num_states)])
            self.q_target [x_terminal, x_terminal] = np.ones(shape=(self.num_actions,))*np.nan
            self.greedy_dx[x_terminal, x_terminal] = np.nan
            
    def _intialise_priority_datapoints(self, env) :
        priority_X, priority_a_idcs, priority_values = [], [], []
        for x_terminal in range(self.num_states) :
            possible_x = []
            if x_terminal > 0 : 
                possible_x.append(x_terminal - 1)
            if x_terminal < self.num_states - 1 : 
                possible_x.append(x_terminal + 1)
            for x in possible_x :
                a_idx = np.argmax(self.q_true[x_terminal, x])
                r, _  = env.perform_action(x_terminal, x, a_idx)
                priority_X     .append([x_terminal, x])
                priority_a_idcs.append([0, a_idx])
                priority_values.append(r)
        self.priority_X      = tf.constant(priority_X)
        self.priority_a_idcs = tf.constant(priority_a_idcs)
        self.priority_values = tf.constant(priority_values)
        self.num_priority_datapoints = len(priority_values)
        
    def _intialise_q_true(self, env) :
        q_true = np.zeros(shape=(self.num_states, self.num_states, self.num_actions))
        self.greedy_dx_true = np.zeros(shape=(self.num_states, self.num_states))
        for x_terminal in range(self.num_states) :
            q_true[x_terminal] = env.get_q_true(x_terminal)
            for x in range(self.num_states) :
                if x == x_terminal :
                    self.greedy_dx_true[x_terminal, x] = np.nan
                    continue
                self.greedy_dx_true[x_terminal, x] = env.actions_dx[np.argmax(q_true[x_terminal,x])]
        self.q_true = q_true
        self.max_abs_q_true = np.nanmax(np.fabs(q_true))
        

In [14]:
'''
env = Environment(num_states=20, r_per_turn=-1., r_per_dx=0., r_per_b=-1., gamma=.99)
exp = Experiment(max_epochs=300, batch_size=10, learning_rate=1e-2, clone_after_epochs=2, max_mse_true=0.01,
                 plot_estimate_after_epochs=2, plot_monitors_after_epochs=4, save_objects_after_epochs=4,
                 act_from_true_greedy_policy=True, run_tag="act_from_true_policy")

exp.is_well_configured(debug_to=sys.stdout, enforce_with_exception=True)

exp.run(env)
'''

'\nenv = Environment(num_states=20, r_per_turn=-1., r_per_dx=0., r_per_b=-1., gamma=.99)\nexp = Experiment(max_epochs=300, batch_size=10, learning_rate=1e-2, clone_after_epochs=2, max_mse_true=0.01,\n                 plot_estimate_after_epochs=2, plot_monitors_after_epochs=4, save_objects_after_epochs=4,\n                 act_from_true_greedy_policy=True, run_tag="act_from_true_policy")\n\nexp.is_well_configured(debug_to=sys.stdout, enforce_with_exception=True)\n\nexp.run(env)\n'

In [15]:
'''
env = Environment(num_states=20, r_per_turn=-1., r_per_dx=0., r_per_b=-1., gamma=.99)
exp = Experiment(max_epochs=300, batch_size=10, learning_rate=1e-2, clone_after_epochs=2, max_mse_true=0.01,
                 plot_estimate_after_epochs=2, plot_monitors_after_epochs=4, save_objects_after_epochs=4,
                 run_tag="baseline")

exp.is_well_configured(debug_to=sys.stdout, enforce_with_exception=True)

exp.run(env)
'''

'\nenv = Environment(num_states=20, r_per_turn=-1., r_per_dx=0., r_per_b=-1., gamma=.99)\nexp = Experiment(max_epochs=300, batch_size=10, learning_rate=1e-2, clone_after_epochs=2, max_mse_true=0.01,\n                 plot_estimate_after_epochs=2, plot_monitors_after_epochs=4, save_objects_after_epochs=4,\n                 run_tag="baseline")\n\nexp.is_well_configured(debug_to=sys.stdout, enforce_with_exception=True)\n\nexp.run(env)\n'

In [18]:

env = Environment(num_states=20, r_per_turn=-1., r_per_dx=0., r_per_b=-1., gamma=.99)
exp = Experiment(max_epochs=300, batch_size=10, learning_rate=1e-2, clone_after_epochs=2, max_mse_true=0.02,
                 plot_estimate_after_epochs=2, plot_monitors_after_epochs=4, save_objects_after_epochs=4,
                 priority_weight=0.3, run_tag="priority_0p3")

exp.is_well_configured(debug_to=sys.stdout, enforce_with_exception=True)

exp.run(env)


Environment config:
    num_states  | 20
    x_min       | 0
    x_max       | 19
    num_actions | 3
    actions_dx  | [-1  0  1]
    r_per_turn  | -1.0
    r_per_dx    | 0.0
    r_per_b     | -1.0
    gamma       | 0.99
    Ns*Ns*Na    | 1200
Experiment config:
    max_mse_true                | 0.01
    max_epochs                  | 300
    batch_size                  | 10
    optimizer_type              | sgd
    learning_rate               | 0.01
    plot_estimate_after_epochs  | 2
    plot_monitors_after_epochs  | 4
    save_objects_after_epochs   | 4
    act_from_true_greedy_policy | False
    clone_after_epochs          | 2
    priority_weight             | 0.3
    num_step_returns            | 1
    run_tag                     | priority_0p3
    run_idx                     | 0
    num_priority_datapoints     | 38

Epoch 1 / 300  [t=13.60s]  <loss = 0.29985, mse_true = 62.6168 / 0.0100, max_abs_q = 1.1>          
Epoch 2 / 300  [t=24.63s]  <loss = 0.04915, mse_true = 59.7475 / 0

2022-08-01 13:23:49.297569: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


Epoch 5 / 300  [t=63.97s]  <loss = 0.24920, mse_true = 42.0821 / 0.0100, max_abs_q = 3.5>          
Epoch 6 / 300  [t=76.52s]  <loss = 0.17699, mse_true = 41.8569 / 0.0100, max_abs_q = 3.6>          
Epoch 7 / 300  [t=89.41s]  <loss = 0.29791, mse_true = 35.9510 / 0.0100, max_abs_q = 4.7>          
Epoch 8 / 300  [t=100.57s]  <loss = 0.25427, mse_true = 35.8848 / 0.0100, max_abs_q = 4.8>         
Epoch 9 / 300  [t=113.75s]  <loss = 0.37638, mse_true = 31.0117 / 0.0100, max_abs_q = 5.9>         
Epoch 10 / 300  [t=124.68s]  <loss = 0.34378, mse_true = 30.9475 / 0.0100, max_abs_q = 6.0>        
Epoch 11 / 300  [t=136.70s]  <loss = 0.48135, mse_true = 27.0123 / 0.0100, max_abs_q = 7.0>        
Epoch 12 / 300  [t=147.37s]  <loss = 0.45380, mse_true = 26.9281 / 0.0100, max_abs_q = 7.0>        
Epoch 13 / 300  [t=160.44s]  <loss = 0.61204, mse_true = 23.9048 / 0.0100, max_abs_q = 8.1>        
Epoch 14 / 300  [t=171.39s]  <loss = 0.58677, mse_true = 23.7627 / 0.0100, max_abs_q = 8.1>        


Epoch 87 / 300  [t=1032.89s]  <loss = 3.62145, mse_true = 108.9845 / 0.0100, max_abs_q = 36.4>     
Epoch 88 / 300  [t=1043.33s]  <loss = 3.60790, mse_true = 107.7175 / 0.0100, max_abs_q = 36.4>     
Epoch 89 / 300  [t=1057.50s]  <loss = 3.22810, mse_true = 99.5570 / 0.0100, max_abs_q = 35.6>      
Epoch 90 / 300  [t=1067.96s]  <loss = 3.21763, mse_true = 98.7500 / 0.0100, max_abs_q = 35.6>      
Epoch 91 / 300  [t=1079.81s]  <loss = 2.89698, mse_true = 91.3544 / 0.0100, max_abs_q = 34.9>      
Epoch 92 / 300  [t=1091.69s]  <loss = 2.88086, mse_true = 90.7248 / 0.0100, max_abs_q = 35.0>      
Epoch 93 / 300  [t=1104.46s]  <loss = 2.60722, mse_true = 83.9524 / 0.0100, max_abs_q = 34.4>      
Epoch 94 / 300  [t=1115.02s]  <loss = 2.59045, mse_true = 83.5306 / 0.0100, max_abs_q = 34.4>      
Epoch 95 / 300  [t=1128.34s]  <loss = 2.35063, mse_true = 77.7162 / 0.0100, max_abs_q = 33.9>      
Epoch 96 / 300  [t=1138.79s]  <loss = 2.33254, mse_true = 77.5541 / 0.0100, max_abs_q = 34.0>      


Epoch 169 / 300  [t=2032.16s]  <loss = 0.15652, mse_true = 1.8409 / 0.0100, max_abs_q = 17.6>      
Epoch 170 / 300  [t=2042.74s]  <loss = 0.15444, mse_true = 1.8228 / 0.0100, max_abs_q = 17.5>      
Epoch 171 / 300  [t=2054.48s]  <loss = 0.15609, mse_true = 1.7934 / 0.0100, max_abs_q = 17.6>      
Epoch 172 / 300  [t=2065.08s]  <loss = 0.15422, mse_true = 1.7768 / 0.0100, max_abs_q = 17.5>      
Epoch 173 / 300  [t=2077.76s]  <loss = 0.15423, mse_true = 1.7460 / 0.0100, max_abs_q = 17.6>      
Epoch 174 / 300  [t=2088.39s]  <loss = 0.15249, mse_true = 1.7303 / 0.0100, max_abs_q = 17.5>      
Epoch 175 / 300  [t=2100.16s]  <loss = 0.15104, mse_true = 1.6979 / 0.0100, max_abs_q = 17.7>      
Epoch 176 / 300  [t=2113.73s]  <loss = 0.14939, mse_true = 1.6828 / 0.0100, max_abs_q = 17.6>      
Epoch 177 / 300  [t=2126.50s]  <loss = 0.14671, mse_true = 1.6485 / 0.0100, max_abs_q = 17.8>      
Epoch 178 / 300  [t=2137.51s]  <loss = 0.14508, mse_true = 1.6339 / 0.0100, max_abs_q = 17.7>      


Epoch 251 / 300  [t=3032.38s]  <loss = 0.03436, mse_true = 0.1946 / 0.0100, max_abs_q = 19.3>      
Epoch 252 / 300  [t=3044.13s]  <loss = 0.03395, mse_true = 0.1933 / 0.0100, max_abs_q = 19.3>      
Epoch 253 / 300  [t=3058.13s]  <loss = 0.03385, mse_true = 0.1843 / 0.0100, max_abs_q = 19.3>      
Epoch 254 / 300  [t=3068.70s]  <loss = 0.03345, mse_true = 0.1830 / 0.0100, max_abs_q = 19.3>      
Epoch 255 / 300  [t=3080.55s]  <loss = 0.03339, mse_true = 0.1748 / 0.0100, max_abs_q = 19.3>      
Epoch 256 / 300  [t=3095.69s]  <loss = 0.03300, mse_true = 0.1736 / 0.0100, max_abs_q = 19.3>      
Epoch 257 / 300  [t=3108.42s]  <loss = 0.03298, mse_true = 0.1661 / 0.0100, max_abs_q = 19.3>      
Epoch 258 / 300  [t=3118.87s]  <loss = 0.03260, mse_true = 0.1649 / 0.0100, max_abs_q = 19.3>      
Epoch 259 / 300  [t=3130.71s]  <loss = 0.03261, mse_true = 0.1581 / 0.0100, max_abs_q = 19.3>      
Epoch 260 / 300  [t=3141.16s]  <loss = 0.03224, mse_true = 0.1569 / 0.0100, max_abs_q = 19.3>      


In [19]:

env = Environment(num_states=20, r_per_turn=-1., r_per_dx=0., r_per_b=-1., gamma=.99)
exp = Experiment(max_epochs=400, batch_size=10, learning_rate=1e-2, clone_after_epochs=2, max_mse_true=0.02,
                 plot_estimate_after_epochs=2, plot_monitors_after_epochs=4, save_objects_after_epochs=4,
                 run_tag="multi_step", num_step_returns=2)

exp.is_well_configured(debug_to=sys.stdout, enforce_with_exception=True)

exp.run(env)


Environment config:
    num_states  | 20
    x_min       | 0
    x_max       | 19
    num_actions | 3
    actions_dx  | [-1  0  1]
    r_per_turn  | -1.0
    r_per_dx    | 0.0
    r_per_b     | -1.0
    gamma       | 0.99
    Ns*Ns*Na    | 1200
Experiment config:
    max_mse_true                | 0.01
    max_epochs                  | 400
    batch_size                  | 10
    optimizer_type              | sgd
    learning_rate               | 0.01
    plot_estimate_after_epochs  | 2
    plot_monitors_after_epochs  | 4
    save_objects_after_epochs   | 4
    act_from_true_greedy_policy | False
    clone_after_epochs          | 2
    priority_weight             | 0.0
    num_step_returns            | 2
    run_tag                     | multi_step
    run_idx                     | 0
    num_priority_datapoints     | 38

Directory figures/Helicopter_1D_random_finish/multi_step created
Directory figures/Helicopter_1D_random_finish/multi_step/experiment_0 created
Epoch 1 / 400  [t=13.27s]

Epoch 70 / 400  [t=817.19s]  <loss = 0.43233, mse_true = 9.0530 / 0.0100, max_abs_q = 22.1>        
Epoch 71 / 400  [t=828.42s]  <loss = 0.31009, mse_true = 5.8524 / 0.0100, max_abs_q = 22.6>        
Epoch 72 / 400  [t=838.55s]  <loss = 0.24161, mse_true = 6.0157 / 0.0100, max_abs_q = 23.2>        
Epoch 73 / 400  [t=852.05s]  <loss = 0.25300, mse_true = 3.6109 / 0.0100, max_abs_q = 22.9>        
Epoch 74 / 400  [t=862.21s]  <loss = 0.19547, mse_true = 3.7099 / 0.0100, max_abs_q = 23.8>        
Epoch 75 / 400  [t=874.84s]  <loss = 0.22275, mse_true = 2.0558 / 0.0100, max_abs_q = 23.3>        
Epoch 76 / 400  [t=884.76s]  <loss = 0.17777, mse_true = 2.1060 / 0.0100, max_abs_q = 23.8>        
Epoch 77 / 400  [t=896.73s]  <loss = 0.20898, mse_true = 1.0712 / 0.0100, max_abs_q = 23.3>        
Epoch 78 / 400  [t=906.69s]  <loss = 0.17119, mse_true = 1.0912 / 0.0100, max_abs_q = 23.2>        
Epoch 79 / 400  [t=919.16s]  <loss = 0.20250, mse_true = 0.5312 / 0.0100, max_abs_q = 22.8>        


Epoch 152 / 400  [t=1754.37s]  <loss = 0.03374, mse_true = 0.0281 / 0.0100, max_abs_q = 19.0>      
Epoch 153 / 400  [t=1768.91s]  <loss = 0.03299, mse_true = 0.0271 / 0.0100, max_abs_q = 19.0>      
Epoch 154 / 400  [t=1778.89s]  <loss = 0.03279, mse_true = 0.0269 / 0.0100, max_abs_q = 19.0>      
Epoch 155 / 400  [t=1789.95s]  <loss = 0.03209, mse_true = 0.0260 / 0.0100, max_abs_q = 19.0>      
Epoch 156 / 400  [t=1799.94s]  <loss = 0.03190, mse_true = 0.0258 / 0.0100, max_abs_q = 19.0>      
Epoch 157 / 400  [t=1814.35s]  <loss = 0.03123, mse_true = 0.0250 / 0.0100, max_abs_q = 19.0>      
Epoch 158 / 400  [t=1824.28s]  <loss = 0.03105, mse_true = 0.0248 / 0.0100, max_abs_q = 19.0>      
Epoch 159 / 400  [t=1835.32s]  <loss = 0.03042, mse_true = 0.0241 / 0.0100, max_abs_q = 19.0>      
Epoch 160 / 400  [t=1845.25s]  <loss = 0.03024, mse_true = 0.0239 / 0.0100, max_abs_q = 19.0>      
Epoch 161 / 400  [t=1857.18s]  <loss = 0.02964, mse_true = 0.0232 / 0.0100, max_abs_q = 19.0>      


In [21]:

env = Environment(num_states=20, r_per_turn=-1., r_per_dx=0., r_per_b=-1., gamma=.99)
exp = Experiment(max_epochs=300, batch_size=10, learning_rate=1e-2, clone_after_epochs=1, max_mse_true=0.02,
                 plot_estimate_after_epochs=2, plot_monitors_after_epochs=4, save_objects_after_epochs=4,
                 run_tag="clone_1")

exp.is_well_configured(debug_to=sys.stdout, enforce_with_exception=True)

exp.run(env)


Environment config:
    num_states  | 20
    x_min       | 0
    x_max       | 19
    num_actions | 3
    actions_dx  | [-1  0  1]
    r_per_turn  | -1.0
    r_per_dx    | 0.0
    r_per_b     | -1.0
    gamma       | 0.99
    Ns*Ns*Na    | 1200
Experiment config:
    max_mse_true                | 0.02
    max_epochs                  | 300
    batch_size                  | 10
    optimizer_type              | sgd
    learning_rate               | 0.01
    plot_estimate_after_epochs  | 2
    plot_monitors_after_epochs  | 4
    save_objects_after_epochs   | 4
    act_from_true_greedy_policy | False
    clone_after_epochs          | 1
    priority_weight             | 0.0
    num_step_returns            | 1
    run_tag                     | clone_1
    run_idx                     | 0
    num_priority_datapoints     | 38

Epoch 1 / 300  [t=21.95s]  <loss = 0.33592, mse_true = 61.0710 / 0.0200, max_abs_q = 1.0>          
Epoch 2 / 300  [t=34.96s]  <loss = 0.23342, mse_true = 49.7948 / 0.0200

Epoch 71 / 300  [t=865.60s]  <loss = 2.19457, mse_true = 33.1731 / 0.0200, max_abs_q = 31.4>       
Epoch 72 / 300  [t=878.15s]  <loss = 2.19493, mse_true = 32.5153 / 0.0200, max_abs_q = 30.7>       
Epoch 73 / 300  [t=890.51s]  <loss = 2.19033, mse_true = 32.1105 / 0.0200, max_abs_q = 30.1>       
Epoch 74 / 300  [t=900.59s]  <loss = 2.17853, mse_true = 31.8997 / 0.0200, max_abs_q = 29.6>       
Epoch 75 / 300  [t=913.35s]  <loss = 2.15888, mse_true = 31.8146 / 0.0200, max_abs_q = 29.3>       
Epoch 76 / 300  [t=923.33s]  <loss = 2.12822, mse_true = 31.8174 / 0.0200, max_abs_q = 29.0>       
Epoch 77 / 300  [t=936.97s]  <loss = 2.08699, mse_true = 31.8411 / 0.0200, max_abs_q = 28.8>       
Epoch 78 / 300  [t=946.94s]  <loss = 2.03225, mse_true = 31.8007 / 0.0200, max_abs_q = 28.8>       
Epoch 79 / 300  [t=958.21s]  <loss = 1.95872, mse_true = 31.6420 / 0.0200, max_abs_q = 28.7>       
Epoch 80 / 300  [t=968.16s]  <loss = 1.88097, mse_true = 31.3594 / 0.0200, max_abs_q = 28.7>       


Epoch 153 / 300  [t=1808.48s]  <loss = 0.21733, mse_true = 2.1153 / 0.0200, max_abs_q = 21.5>      
Epoch 154 / 300  [t=1818.54s]  <loss = 0.21322, mse_true = 2.0818 / 0.0200, max_abs_q = 21.5>      
Epoch 155 / 300  [t=1829.61s]  <loss = 0.20919, mse_true = 2.0460 / 0.0200, max_abs_q = 21.5>      
Epoch 156 / 300  [t=1839.54s]  <loss = 0.20527, mse_true = 2.0084 / 0.0200, max_abs_q = 21.5>      
Epoch 157 / 300  [t=1851.62s]  <loss = 0.20146, mse_true = 1.9694 / 0.0200, max_abs_q = 21.5>      
Epoch 158 / 300  [t=1861.59s]  <loss = 0.19776, mse_true = 1.9295 / 0.0200, max_abs_q = 21.5>      
Epoch 159 / 300  [t=1875.63s]  <loss = 0.19418, mse_true = 1.8889 / 0.0200, max_abs_q = 21.5>      
Epoch 160 / 300  [t=1885.57s]  <loss = 0.19071, mse_true = 1.8479 / 0.0200, max_abs_q = 21.5>      
Epoch 161 / 300  [t=1897.59s]  <loss = 0.18736, mse_true = 1.8068 / 0.0200, max_abs_q = 21.5>      
Epoch 162 / 300  [t=1907.50s]  <loss = 0.18413, mse_true = 1.7659 / 0.0200, max_abs_q = 21.5>      


Epoch 235 / 300  [t=2757.54s]  <loss = 0.07385, mse_true = 0.4535 / 0.0200, max_abs_q = 19.7>      
Epoch 236 / 300  [t=2767.46s]  <loss = 0.07302, mse_true = 0.4484 / 0.0200, max_abs_q = 19.7>      
Epoch 237 / 300  [t=2779.51s]  <loss = 0.07220, mse_true = 0.4434 / 0.0200, max_abs_q = 19.7>      
Epoch 238 / 300  [t=2789.46s]  <loss = 0.07139, mse_true = 0.4383 / 0.0200, max_abs_q = 19.7>      
Epoch 239 / 300  [t=2804.55s]  <loss = 0.07058, mse_true = 0.4333 / 0.0200, max_abs_q = 19.7>      
Epoch 240 / 300  [t=2814.56s]  <loss = 0.06978, mse_true = 0.4282 / 0.0200, max_abs_q = 19.7>      
Epoch 241 / 300  [t=2826.64s]  <loss = 0.06898, mse_true = 0.4231 / 0.0200, max_abs_q = 19.7>      
Epoch 242 / 300  [t=2836.61s]  <loss = 0.06819, mse_true = 0.4180 / 0.0200, max_abs_q = 19.7>      
Epoch 243 / 300  [t=2847.62s]  <loss = 0.06740, mse_true = 0.4129 / 0.0200, max_abs_q = 19.7>      
Epoch 244 / 300  [t=2857.55s]  <loss = 0.06661, mse_true = 0.4078 / 0.0200, max_abs_q = 19.6>      


In [22]:

env = Environment(num_states=20, r_per_turn=-1., r_per_dx=1., r_per_b=-1., gamma=.99)
exp = Experiment(max_epochs=300, batch_size=10, learning_rate=1e-2, clone_after_epochs=2, max_mse_true=0.02,
                 plot_estimate_after_epochs=2, plot_monitors_after_epochs=4, save_objects_after_epochs=4,
                 run_tag="easy_function")

exp.is_well_configured(debug_to=sys.stdout, enforce_with_exception=True)

exp.run(env)



Environment config:
    num_states  | 20
    x_min       | 0
    x_max       | 19
    num_actions | 3
    actions_dx  | [-1  0  1]
    r_per_turn  | -1.0
    r_per_dx    | 1.0
    r_per_b     | -1.0
    gamma       | 0.99
    Ns*Ns*Na    | 1200
Experiment config:
    max_mse_true                | 0.02
    max_epochs                  | 300
    batch_size                  | 10
    optimizer_type              | sgd
    learning_rate               | 0.01
    plot_estimate_after_epochs  | 2
    plot_monitors_after_epochs  | 4
    save_objects_after_epochs   | 4
    act_from_true_greedy_policy | False
    clone_after_epochs          | 2
    priority_weight             | 0.0
    num_step_returns            | 1
    run_tag                     | easy_function
    run_idx                     | 0
    num_priority_datapoints     | 38

Directory figures/Helicopter_1D_random_finish/easy_function created
Directory figures/Helicopter_1D_random_finish/easy_function/experiment_0 created
Epoch 1 / 300  [

Epoch 70 / 300  [t=789.23s]  <loss = 1.96965, mse_true = 43.3051 / 0.0200, max_abs_q = 11.7>       
Epoch 71 / 300  [t=800.88s]  <loss = 1.95317, mse_true = 43.5415 / 0.0200, max_abs_q = 11.8>       
Epoch 72 / 300  [t=810.84s]  <loss = 1.93844, mse_true = 43.6484 / 0.0200, max_abs_q = 12.0>       
Epoch 73 / 300  [t=823.36s]  <loss = 1.90535, mse_true = 43.5312 / 0.0200, max_abs_q = 12.0>       
Epoch 74 / 300  [t=833.33s]  <loss = 1.88458, mse_true = 43.7073 / 0.0200, max_abs_q = 12.2>       
Epoch 75 / 300  [t=845.06s]  <loss = 1.83621, mse_true = 43.3007 / 0.0200, max_abs_q = 12.3>       
Epoch 76 / 300  [t=855.02s]  <loss = 1.81534, mse_true = 43.4987 / 0.0200, max_abs_q = 12.4>       
Epoch 77 / 300  [t=867.53s]  <loss = 1.74533, mse_true = 42.8184 / 0.0200, max_abs_q = 12.5>       
Epoch 78 / 300  [t=877.46s]  <loss = 1.72076, mse_true = 43.0429 / 0.0200, max_abs_q = 12.6>       
Epoch 79 / 300  [t=889.23s]  <loss = 1.63261, mse_true = 42.0727 / 0.0200, max_abs_q = 12.7>       


Epoch 152 / 300  [t=1691.55s]  <loss = 0.02867, mse_true = 0.1430 / 0.0200, max_abs_q = 3.2>       
Epoch 153 / 300  [t=1703.02s]  <loss = 0.02986, mse_true = 0.1442 / 0.0200, max_abs_q = 3.2>       
Epoch 154 / 300  [t=1714.63s]  <loss = 0.02946, mse_true = 0.1432 / 0.0200, max_abs_q = 3.2>       
Epoch 155 / 300  [t=1725.28s]  <loss = 0.03059, mse_true = 0.1471 / 0.0200, max_abs_q = 3.2>       
Epoch 156 / 300  [t=1735.22s]  <loss = 0.03023, mse_true = 0.1464 / 0.0200, max_abs_q = 3.2>       
Epoch 157 / 300  [t=1748.45s]  <loss = 0.03127, mse_true = 0.1524 / 0.0200, max_abs_q = 3.2>       
Epoch 158 / 300  [t=1758.57s]  <loss = 0.03093, mse_true = 0.1518 / 0.0200, max_abs_q = 3.2>       
Epoch 159 / 300  [t=1769.50s]  <loss = 0.03186, mse_true = 0.1594 / 0.0200, max_abs_q = 3.2>       
Epoch 160 / 300  [t=1780.72s]  <loss = 0.03155, mse_true = 0.1589 / 0.0200, max_abs_q = 3.1>       
Epoch 161 / 300  [t=1794.95s]  <loss = 0.03237, mse_true = 0.1678 / 0.0200, max_abs_q = 3.2>       


Epoch 234 / 300  [t=2637.50s]  <loss = 0.01586, mse_true = 0.1474 / 0.0200, max_abs_q = 3.3>       
Epoch 235 / 300  [t=2648.86s]  <loss = 0.01563, mse_true = 0.1425 / 0.0200, max_abs_q = 3.3>       
Epoch 236 / 300  [t=2659.52s]  <loss = 0.01556, mse_true = 0.1423 / 0.0200, max_abs_q = 3.3>       
Epoch 237 / 300  [t=2673.70s]  <loss = 0.01533, mse_true = 0.1377 / 0.0200, max_abs_q = 3.3>       
Epoch 238 / 300  [t=2684.39s]  <loss = 0.01527, mse_true = 0.1375 / 0.0200, max_abs_q = 3.3>       
Epoch 239 / 300  [t=2695.25s]  <loss = 0.01506, mse_true = 0.1331 / 0.0200, max_abs_q = 3.2>       
Epoch 240 / 300  [t=2705.75s]  <loss = 0.01499, mse_true = 0.1330 / 0.0200, max_abs_q = 3.2>       
Epoch 241 / 300  [t=2717.73s]  <loss = 0.01479, mse_true = 0.1288 / 0.0200, max_abs_q = 3.2>       
Epoch 242 / 300  [t=2731.78s]  <loss = 0.01473, mse_true = 0.1286 / 0.0200, max_abs_q = 3.2>       
Epoch 243 / 300  [t=2742.82s]  <loss = 0.01454, mse_true = 0.1246 / 0.0200, max_abs_q = 3.2>       


In [24]:

env = Environment(num_states=20, r_per_turn=-1., r_per_dx=0., r_per_b=-1., gamma=.99)
exp = Experiment(max_epochs=600, batch_size=10, learning_rate=1e-2, clone_after_epochs=6, max_mse_true=0.02,
                 plot_estimate_after_epochs=6, plot_monitors_after_epochs=6, save_objects_after_epochs=12,
                 run_tag="clone_6")

exp.is_well_configured(debug_to=sys.stdout, enforce_with_exception=True)

exp.run(env)



Environment config:
    num_states  | 20
    x_min       | 0
    x_max       | 19
    num_actions | 3
    actions_dx  | [-1  0  1]
    r_per_turn  | -1.0
    r_per_dx    | 0.0
    r_per_b     | -1.0
    gamma       | 0.99
    Ns*Ns*Na    | 1200
Experiment config:
    max_mse_true                | 0.02
    max_epochs                  | 600
    batch_size                  | 10
    optimizer_type              | sgd
    learning_rate               | 0.01
    plot_estimate_after_epochs  | 6
    plot_monitors_after_epochs  | 6
    save_objects_after_epochs   | 12
    act_from_true_greedy_policy | False
    clone_after_epochs          | 6
    priority_weight             | 0.0
    num_step_returns            | 1
    run_tag                     | clone_6
    run_idx                     | 0
    num_priority_datapoints     | 38

Epoch 1 / 600  [t=14.10s]  <loss = 0.33415, mse_true = 61.0265 / 0.0200, max_abs_q = 1.0>          
Epoch 2 / 600  [t=25.68s]  <loss = 0.04349, mse_true = 58.6922 / 0.020

Epoch 71 / 600  [t=895.52s]  <loss = 2.16048, mse_true = 17.5944 / 0.0200, max_abs_q = 10.8>       
Epoch 72 / 600  [t=908.55s]  <loss = 2.15670, mse_true = 17.5440 / 0.0200, max_abs_q = 10.8>       
Epoch 73 / 600  [t=923.65s]  <loss = 2.37509, mse_true = 18.5752 / 0.0200, max_abs_q = 11.5>       
Epoch 74 / 600  [t=936.73s]  <loss = 2.36964, mse_true = 18.4662 / 0.0200, max_abs_q = 11.6>       
Epoch 75 / 600  [t=949.02s]  <loss = 2.36387, mse_true = 18.3690 / 0.0200, max_abs_q = 11.7>       
Epoch 76 / 600  [t=961.97s]  <loss = 2.35902, mse_true = 18.2844 / 0.0200, max_abs_q = 11.7>       
Epoch 77 / 600  [t=974.12s]  <loss = 2.35530, mse_true = 18.2120 / 0.0200, max_abs_q = 11.7>       
Epoch 78 / 600  [t=987.28s]  <loss = 2.35153, mse_true = 18.1443 / 0.0200, max_abs_q = 11.7>       
Epoch 79 / 600  [t=1000.92s]  <loss = 2.53990, mse_true = 19.3280 / 0.0200, max_abs_q = 12.4>      
Epoch 80 / 600  [t=1013.93s]  <loss = 2.53140, mse_true = 19.1863 / 0.0200, max_abs_q = 12.5>      


Epoch 153 / 600  [t=1937.80s]  <loss = 1.13063, mse_true = 11.8826 / 0.0200, max_abs_q = 21.3>     
Epoch 154 / 600  [t=1948.95s]  <loss = 1.12522, mse_true = 11.8847 / 0.0200, max_abs_q = 21.3>     
Epoch 155 / 600  [t=1962.31s]  <loss = 1.11952, mse_true = 11.8891 / 0.0200, max_abs_q = 21.3>     
Epoch 156 / 600  [t=1973.56s]  <loss = 1.11422, mse_true = 11.8940 / 0.0200, max_abs_q = 21.2>     
Epoch 157 / 600  [t=1988.07s]  <loss = 0.98130, mse_true = 11.2425 / 0.0200, max_abs_q = 21.0>     
Epoch 158 / 600  [t=1999.25s]  <loss = 1.02884, mse_true = 11.1971 / 0.0200, max_abs_q = 21.7>     
Epoch 159 / 600  [t=2010.61s]  <loss = 1.01285, mse_true = 11.1740 / 0.0200, max_abs_q = 21.6>     
Epoch 160 / 600  [t=2023.08s]  <loss = 1.00125, mse_true = 11.1624 / 0.0200, max_abs_q = 21.5>     
Epoch 161 / 600  [t=2034.37s]  <loss = 0.99140, mse_true = 11.1567 / 0.0200, max_abs_q = 21.5>     
Epoch 162 / 600  [t=2046.27s]  <loss = 0.98387, mse_true = 11.1553 / 0.0200, max_abs_q = 21.4>     


Epoch 235 / 600  [t=2933.43s]  <loss = 0.09691, mse_true = 1.8473 / 0.0200, max_abs_q = 21.2>      
Epoch 236 / 600  [t=2946.71s]  <loss = 0.06970, mse_true = 1.8277 / 0.0200, max_abs_q = 21.6>      
Epoch 237 / 600  [t=2959.70s]  <loss = 0.06914, mse_true = 1.8153 / 0.0200, max_abs_q = 21.6>      
Epoch 238 / 600  [t=2972.88s]  <loss = 0.06873, mse_true = 1.8057 / 0.0200, max_abs_q = 21.6>      
Epoch 239 / 600  [t=2984.15s]  <loss = 0.06837, mse_true = 1.7981 / 0.0200, max_abs_q = 21.6>      
Epoch 240 / 600  [t=2995.31s]  <loss = 0.06799, mse_true = 1.7918 / 0.0200, max_abs_q = 21.6>      
Epoch 241 / 600  [t=3011.08s]  <loss = 0.08360, mse_true = 1.5356 / 0.0200, max_abs_q = 21.1>      
Epoch 242 / 600  [t=3022.13s]  <loss = 0.05727, mse_true = 1.5187 / 0.0200, max_abs_q = 21.6>      
Epoch 243 / 600  [t=3034.77s]  <loss = 0.05679, mse_true = 1.5080 / 0.0200, max_abs_q = 21.6>      
Epoch 244 / 600  [t=3045.91s]  <loss = 0.05652, mse_true = 1.4998 / 0.0200, max_abs_q = 21.6>      


Epoch 317 / 600  [t=4004.06s]  <loss = 0.02504, mse_true = 0.2208 / 0.0200, max_abs_q = 20.0>      
Epoch 318 / 600  [t=4016.54s]  <loss = 0.02491, mse_true = 0.2200 / 0.0200, max_abs_q = 20.0>      
Epoch 319 / 600  [t=4030.68s]  <loss = 0.02577, mse_true = 0.1986 / 0.0200, max_abs_q = 19.9>      
Epoch 320 / 600  [t=4042.85s]  <loss = 0.02513, mse_true = 0.1977 / 0.0200, max_abs_q = 19.9>      
Epoch 321 / 600  [t=4055.19s]  <loss = 0.02494, mse_true = 0.1968 / 0.0200, max_abs_q = 19.9>      
Epoch 322 / 600  [t=4070.92s]  <loss = 0.02477, mse_true = 0.1961 / 0.0200, max_abs_q = 19.9>      
Epoch 323 / 600  [t=4083.35s]  <loss = 0.02462, mse_true = 0.1953 / 0.0200, max_abs_q = 19.9>      
Epoch 324 / 600  [t=4095.51s]  <loss = 0.02449, mse_true = 0.1946 / 0.0200, max_abs_q = 19.9>      
Epoch 325 / 600  [t=4110.09s]  <loss = 0.02513, mse_true = 0.1761 / 0.0200, max_abs_q = 19.8>      
Epoch 326 / 600  [t=4125.97s]  <loss = 0.02466, mse_true = 0.1754 / 0.0200, max_abs_q = 19.8>      


Epoch 399 / 600  [t=5135.48s]  <loss = 0.01855, mse_true = 0.0515 / 0.0200, max_abs_q = 19.3>      
Epoch 400 / 600  [t=5147.60s]  <loss = 0.01850, mse_true = 0.0515 / 0.0200, max_abs_q = 19.3>      
Epoch 401 / 600  [t=5159.79s]  <loss = 0.01845, mse_true = 0.0514 / 0.0200, max_abs_q = 19.3>      
Epoch 402 / 600  [t=5176.07s]  <loss = 0.01841, mse_true = 0.0513 / 0.0200, max_abs_q = 19.3>      
Epoch 403 / 600  [t=5189.95s]  <loss = 0.01813, mse_true = 0.0477 / 0.0200, max_abs_q = 19.3>      
Epoch 404 / 600  [t=5202.03s]  <loss = 0.01811, mse_true = 0.0476 / 0.0200, max_abs_q = 19.3>      
Epoch 405 / 600  [t=5214.08s]  <loss = 0.01807, mse_true = 0.0475 / 0.0200, max_abs_q = 19.3>      
Epoch 406 / 600  [t=5226.40s]  <loss = 0.01802, mse_true = 0.0474 / 0.0200, max_abs_q = 19.3>      
Epoch 407 / 600  [t=5238.42s]  <loss = 0.01798, mse_true = 0.0473 / 0.0200, max_abs_q = 19.3>      
Epoch 408 / 600  [t=5254.77s]  <loss = 0.01794, mse_true = 0.0473 / 0.0200, max_abs_q = 19.3>      


Epoch 481 / 600  [t=6212.28s]  <loss = 0.01275, mse_true = 0.0245 / 0.0200, max_abs_q = 19.3>      
Epoch 482 / 600  [t=6223.49s]  <loss = 0.01274, mse_true = 0.0244 / 0.0200, max_abs_q = 19.3>      
Epoch 483 / 600  [t=6234.68s]  <loss = 0.01272, mse_true = 0.0244 / 0.0200, max_abs_q = 19.3>      
Epoch 484 / 600  [t=6246.02s]  <loss = 0.01271, mse_true = 0.0243 / 0.0200, max_abs_q = 19.3>      
Epoch 485 / 600  [t=6257.54s]  <loss = 0.01269, mse_true = 0.0243 / 0.0200, max_abs_q = 19.3>      
Epoch 486 / 600  [t=6272.73s]  <loss = 0.01267, mse_true = 0.0243 / 0.0200, max_abs_q = 19.3>      
Epoch 487 / 600  [t=6285.25s]  <loss = 0.01242, mse_true = 0.0238 / 0.0200, max_abs_q = 19.3>      
Epoch 488 / 600  [t=6296.97s]  <loss = 0.01242, mse_true = 0.0237 / 0.0200, max_abs_q = 19.3>      
Epoch 489 / 600  [t=6308.17s]  <loss = 0.01240, mse_true = 0.0237 / 0.0200, max_abs_q = 19.3>      
Epoch 490 / 600  [t=6319.25s]  <loss = 0.01238, mse_true = 0.0236 / 0.0200, max_abs_q = 19.3>      
