- Monitors are observed to follow a damped oscillation when the learning rate is too large: we repeatedly overshoot the target, even when the target is not especially mis-modelled

- The function appears to be dragged around, since its inductive bias is to be very smooth. This smoothness is reinforced by the target, reward + gamma * bootstrap, which is itself smooth. This has two effects: (1) if the function is not very flexible then it will struggle to mould the near-terminal states into their correct positions, because all of the other datapoints outweigh them, and (2) it will struggle to capture sharp turning points in the value function. In this example, we can clearly see that in trying to model the "long arm", which has more states and so receives a higher effective weight in the gradient update, we enact a lever effect which pulls the "short arm" in the wrong direction. Ideally we would update the NN parameters to fold the value function in the middle and so describe both arms well.
    - it is possible that this effect occurs because the NN is too simple, so will add capacity and see if that resolves it
    - even if a lever effect does not occur, we still focus on the arm with more states, since its gradient updates take precedence when the two arms pull in opposite directions

- On using q1 and q2 as two different models, compared with using just q1 and a frozen copy of itself to bootstrap from: in the first case, it seems that one function simply leads the other, i.e. we update q1 to a new iteration, then update q2 to catch up with q1, then iterate q1 again. We are therefore not learning very efficiently, since 50% of the time we are duplicating progress. It is more efficient to freeze the original model and avoid this duplication.

- Currently use SGD to avoid confounding learning momentum with the bias/divergence of FA + Q-learning, which is what I am trying to understand

- Maybe some of the massive jumps which occur seemingly out-of-nowhere occur because the greedy policy suddenly changes and starts selecting a different action, therefore dramatically changing the bootstrap target. Should highlight on the plot which action the bootstrap target selects each time. Might also just be due to learning rate being too high.

- "Good" solution works with both a simple NN and a complicated one
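
The lever effect described in the notes above can be reproduced with a deliberately inflexible model. The sketch below is an illustration only, not the notebook's network: it fits a straight line by least squares to a V-shaped state-value function with the terminal state at x = 16 on a 20-state board, ignoring boundary penalties. The names `true_value` and `fit_line` are hypothetical helpers introduced for this example.

```python
# Toy illustration of the "lever effect": a function that cannot bend (here a
# straight line) is fitted by least squares to a V-shaped value function.
X_END = 16                                      # terminal state (assumed, matches the board config)
STATES = [x for x in range(20) if x != X_END]   # non-terminal states 0..19

def true_value(x):
    """State value for reward -1 per step and gamma = 1 (boundary effects ignored)."""
    return -abs(x - X_END)

def fit_line(xs, ys):
    """Ordinary least-squares fit y ~ slope * x + intercept."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    slope = cov / var
    return slope, mean_y - slope * mean_x

slope, intercept = fit_line(STATES, [true_value(x) for x in STATES])
fitted = {x: slope * x + intercept for x in STATES}

# The 16 long-arm states (x < 16) outvote the 3 short-arm states (x > 16), so the
# fitted slope is positive and the short arm is tilted the wrong way.
print(f"fitted slope = {slope:+.3f}")
print(f"short arm, v(19) - v(17): true = {true_value(19) - true_value(17):+.1f}, "
      f"fitted = {fitted[19] - fitted[17]:+.3f}")
```

A model with enough capacity to fold at x = 16 would not be forced into this trade-off, which is the motivation for adding capacity in the notes.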

In [1]:
#  Required imports

import math, os, pickle, sys, time

import numpy as np

from matplotlib import animation, pyplot as plt

from scipy import stats

from multiprocess import Process, Value
from threading import Lock, Thread

import threading

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers    import Conv2D, Concatenate, Dense, Dropout, Flatten, Input, MaxPooling2D, Rescaling
from tensorflow.keras.models    import Model
from tensorflow.keras.callbacks import EarlyStopping

print("TensorFlow has found devices:")
for device in tf.config.list_physical_devices() :
    print(f"-  {device}")
    
  # create global list of all threads we will create
all_threads = []




TensorFlow has found devices:
-  PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
-  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [2]:
###
###  Global constants
###

#  Initially we will just run on a game board of fixed size, to avoid building an architecture to handle 
#  variable board sizes, so let's configure this here

distance_to_end     = 13
game_board_pad_size = 3
horizontal_size     = distance_to_end + 2*game_board_pad_size + 1

reward_per_turn = -1.
lambda_dx       = 0.
lambda_b        = -1.
gamma           = 1.

x_min      = 0
x_max      = horizontal_size - 1
x_range    = x_max - x_min
x_start    = game_board_pad_size
x_end      = x_start + distance_to_end
num_states = x_range + 1

action_list  = [-1, 0, 1]
num_actions  = len(action_list)
action_list  = np.array(action_list).reshape((num_actions,1))
a_min, a_max = action_list.min(), action_list.max()
a_range      = a_max - a_min

print(f"Using game board config: x_min = {x_min}, x_max = {x_max}, x_start = {x_start}, x_end = {x_end} ({num_states} states)")
print(f"Using {num_actions} available actions: dx = {action_list.flatten()} with min {a_min} and max {a_max}")


Using game board config: x_min = 0, x_max = 19, x_start = 3, x_end = 16 (20 states)
Using 3 available actions: dx = [-1  0  1] with min -1 and max 1


In [3]:
###
###  Define environment methods
###


def is_terminal(x_agent) :
    '''
    Return True if the agent is in the terminal state and False otherwise.
    Inputs:
      > x_agent, int [x_min, x_max]
        x position of agent
    Returns:
      > bool
        whether the agent is in the terminal state
    '''
    if x_agent == x_end :
        return True
    return False


def is_out_of_bounds(x_agent) :
    '''
    Return True if the agent is out of bounds and False otherwise.
    Inputs:
      > x_agent, int [x_min, x_max]
        x position of agent
    Returns:
      > bool
        whether the agent is out of bounds
    '''
    if x_agent < x_min : return True
    if x_agent > x_max : return True
    return False


def perform_action(x_agent, action, base_reward=reward_per_turn, boundary_reward=lambda_b, dx_reward=lambda_dx) :
    '''
    Given the current environment and agent states, perform the specified action and return the reward 
    obtained along with the new agent state.
    Inputs:
      > x_agent, int [x_min, x_max]
        x position of agent at initial timestep
      > action, int in action_list
        dx of action to be performed
      > base_reward, float, default=reward_per_turn
        basic reward returned every turn (expected -ve)
      > boundary_reward, float, default=lambda_b
        reward received when encountering the edge of the game board (expected -ve)
      > dx_reward, float, default=lambda_dx
        factor multiplied by change-in-distance to calculate movement reward (expected +ve def)
    Returns:
      > float
        reward obtained by performing action
      > int [x_min, x_max]
        x position of agent at iterated timestep
    '''
    ##  Make sure initial state is valid to protect against unexpected behaviour
    if is_terminal(x_agent) :
        raise RuntimeError(f"Agent position is terminal, so no actions may be performed")
    if is_out_of_bounds(x_agent) :
        raise RuntimeError(f"Agent position ({x_agent}) is out of bounds, so no actions may be performed")
    ##  Make sure action is valid to protect against unexpected behaviour
    if action not in action_list :
        raise RuntimeError(f"Action ({action}) not found in available list ({action_list})")
    ##  Get initial distance of agent from the end
    dx_agent = np.fabs(x_agent - x_end)
    ##  Iterate agent position, if hit boundary then add penalty and return to original position 
    x_agent_p = x_agent + action
    reward_b  = 0
    if is_out_of_bounds(x_agent_p) :
        reward_b  = boundary_reward
        x_agent_p = x_agent  # stay in place; plain Python ints have no .copy() and nothing mutates x_agent
    ##  Get distance-based reward
    dx_agent_p = np.fabs(x_agent_p - x_end)
    reward_dx  = dx_reward * (dx_agent - dx_agent_p)
    ##  Calculate total reward by summing the base, boundary and distance rewards
    reward = base_reward + reward_b + reward_dx
    ##  Return reward and new agent state
    return reward, x_agent_p
    

def get_greedy_action(x_agent, *q_models) :
    '''
    Sample a greedy action from the q-value models provided. If multiple models provided then use their mean.
    Inputs:
      > x_agent, int [x_min, x_max]
        x position of agent at initial timestep
      > q_models, list of tf.keras Model class, each with inputs [x_agent, action] = [Input(1), Input(1)]
        list of Keras q(s,a) models
    Returns:
      > int in action_list
        action defined by greedy policy over the model(s) at this agent position
    '''
    x_agents            = np.array([x_agent for i in range(num_actions)]).reshape((num_actions, 1))
    model_args          = [x_agents, action_list]
    model_action_values = [model.predict(model_args).flatten() for model in q_models]
    action_values       = np.mean(model_action_values, axis=0)
    best_action         = action_list[np.argmax(action_values)][0]
    return best_action
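
The averaging-then-argmax step in get_greedy_action can be checked without Keras. The sketch below is a plain-Python stand-in, assuming each model's predictions have already been collected into a list ordered like action_list; `greedy_action` and the example values are hypothetical. It also shows how the greedy action at a state can differ between a single model and the mean of two, which is one way a bootstrap target could change abruptly.

```python
ACTIONS = [-1, 0, 1]   # same order as action_list

def greedy_action(*action_values):
    """Average per-action values across models, then take the first arg-max."""
    n_models = len(action_values)
    mean_values = [sum(vals[i] for vals in action_values) / n_models
                   for i in range(len(ACTIONS))]
    best_idx = max(range(len(ACTIONS)), key=lambda i: mean_values[i])
    return ACTIONS[best_idx], mean_values

# Hypothetical action values from two models at a single state
q1_values = [-3.0, -2.0, -1.0]
q2_values = [-1.0, -2.0, -2.5]

print(greedy_action(q1_values))             # q1 alone
print(greedy_action(q2_values))             # q2 alone
print(greedy_action(q1_values, q2_values))  # mean of both
```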


def get_state_action_pairs() :
    state_action_pairs = []
    for x_agent in range(x_min, x_max+1) : 
        if is_terminal(x_agent) : continue
        if is_out_of_bounds(x_agent) :
            raise RuntimeError(f"Trying to add out-of-bounds state x={x_agent} to state_action_pairs")
        for action in action_list.flatten() :
            state_action_pairs.append((x_agent, action))
    return np.array(state_action_pairs)


def get_true_q(states, actions) :
    num_states = len(states)
    q_values = np.zeros(shape=(num_states,))
    for x_idx, (s, a) in enumerate(zip(states, actions)) :
        if is_terminal(s) :
            q_values[x_idx] = np.nan
            continue
        g, s = perform_action(s, a)
        while not is_terminal(s) :
            if s > x_end : a = -1
            else         : a = 1
            r, s = perform_action(s, a)
            g += r
        q_values[x_idx] = g
    return q_values
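
Because get_true_q walks straight to x_end after the first action, its result has a simple closed form when gamma = 1 and lambda_dx = 0, as configured above. The sketch below is a cross-check under those assumptions, with the board constants copied from the printed config; `closed_form_q` is a hypothetical helper, not part of the notebook.

```python
X_MIN, X_MAX, X_END = 0, 19, 16          # board constants printed by the config cell
REWARD_PER_TURN, LAMBDA_B = -1.0, -1.0   # reward_per_turn and lambda_b

def closed_form_q(x, action):
    """q*(x, action) for gamma = 1 and lambda_dx = 0; NaN at the terminal state."""
    if x == X_END:
        return float("nan")
    x_next, reward = x + action, REWARD_PER_TURN
    if x_next < X_MIN or x_next > X_MAX:
        # bounced off the edge: boundary penalty, agent stays where it was
        x_next, reward = x, reward + LAMBDA_B
    # afterwards the agent walks straight to x_end at REWARD_PER_TURN per step
    return reward + REWARD_PER_TURN * abs(x_next - X_END)

for x, a in [(3, 1), (15, 1), (19, 1), (0, -1)]:
    print(f"q*({x}, {a:+d}) = {closed_form_q(x, a)}")
```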


def get_mse(q_values_1, q_values_2) :
    q_res = q_values_2 - q_values_1
    q_res = np.where(np.isfinite(q_res), q_res, 0)
    q_res = q_res**2
    return np.mean(q_res)
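
One property of get_mse worth keeping in mind: non-finite residuals (the terminal-state entries, which are NaN in true_q) are replaced by zero but still counted in the mean, so the result is slightly smaller than a mean over finite entries only. The plain-Python sketch below compares the two conventions; both function names and the example values are hypothetical.

```python
import math

def mse_zero_filled(a, b):
    """Mirror of get_mse: non-finite residuals become 0 but still count in the mean."""
    res = [(y - x) if math.isfinite(y - x) else 0.0 for x, y in zip(a, b)]
    return sum(r * r for r in res) / len(res)

def mse_masked(a, b):
    """Alternative: drop non-finite residuals before averaging."""
    res = [y - x for x, y in zip(a, b) if math.isfinite(y - x)]
    return sum(r * r for r in res) / len(res)

estimate = [1.0, 2.0, 3.0, 0.5]
truth    = [2.0, 4.0, math.nan, 0.5]   # NaN plays the role of a terminal-state entry

print(mse_zero_filled(estimate, truth))
print(mse_masked(estimate, truth))
```

Since the number of NaN entries is fixed here, the two differ only by a constant factor, so comparisons between epochs are unaffected.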
    

In [4]:
def generate_directory_for_file_path(fname, print_msg_on_dir_creation=True) :
    """
    Create the directory structure needed to place file fname. Call this before fig.savefig(fname, ...) to 
    make sure fname can be created without a FileNotFoundError
    Input:
       - fname: str
                name of file you want to create a tree of directories to enclose
                also create directory at this path if fname ends in '/'
       - print_msg_on_dir_creation: bool, default = True
                                    if True then print a message whenever a new directory is created
    """
    while "//" in fname :
        fname = fname.replace("//", "/")
    dir_tree = fname.split("/")
    dir_tree = ["/".join(dir_tree[:i]) for i in range(1,len(dir_tree))]
    dir_path = ""
    for dir_path in dir_tree :
        if len(dir_path) == 0 : continue
        if not os.path.exists(dir_path) :
            os.mkdir(dir_path)
            if print_msg_on_dir_creation :
                print(f"Directory {dir_path} created")
            continue
        if os.path.isdir(dir_path) : 
            continue
        raise RuntimeError(f"Cannot create directory {dir_path} because it already exists and is not a directory")
    

def create_config(config_fname, run_config, q1_model, q2_model, optimizer_q1, optimizer_q2, to_stdout=True) :
    '''
    Print environment, training and model configurations to file config_fname. Also print environment and
    training configurations to sys.stdout if requested, but do not print model summaries as they are verbose.
    Inputs:
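      > config_fname, str
        name of config file to create
      > run_config, dict
        training configuration, queried using run_config.get(...)
      > q1_model, keras Model
        first q-value model
      > q2_model, keras Model
        second q-value model
      > optimizer_q1, keras optimizer
        optimizer used to update q1_model
      > optimizer_q2, keras optimizer or None
        optimizer used to update q2_model, None if q2_model is not trained directly
      > to_stdout, bool, default=True
        if True then repeat environment and training configurations to sys.stdout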
    Returns:
      > None
    '''
    # Create message as list of strings
    config_message = []
    config_message.append(f"="*114 + "\n")
    config_message.append(f"Environment config:\n")
    config_message.append(f"> distance_to_end: {distance_to_end}\n")
    config_message.append(f"> game_board_pad_size: {game_board_pad_size}\n")
    config_message.append(f"> horizontal_size: {horizontal_size}\n")
    config_message.append(f"> reward_per_turn: {reward_per_turn}\n")
    config_message.append(f"> lambda_dx: {lambda_dx}\n")
    config_message.append(f"> lambda_b: {lambda_b}\n")
    config_message.append(f"> gamma: {gamma}\n")
    config_message.append(f"> x_min: {x_min}\n")
    config_message.append(f"> x_max: {x_max}\n")
    config_message.append(f"> x_range: {x_range}\n")
    config_message.append(f"> x_start: {x_start}\n")
    config_message.append(f"> x_end: {x_end}\n")
    config_message.append(f"> action_list: {action_list.flatten()}\n")
    config_message.append(f"="*114 + "\n")
    config_message.append(f"Training config:\n")
    config_message.append(f"> Stop training when mse_true falls to or below {run_config.get('max_mse_true')}\n")
    config_message.append(f"> Using {run_config.get('num_step_returns')} step empirical returns\n")
    config_message.append(f"> Using bootstrap method: {run_config.get('bootstrap_method')}\n")
    config_message.append(f"> Using epochs of length {run_config.get('num_state_action_pairs')}\n")
    config_message.append(f"> Updating gradient every batch of size {run_config.get('batch_size')}\n")
    config_message.append(f"> Using optimizer_q1 {optimizer_q1} with learning rate {run_config.get('learning_rate'):.6}\n")
    config_message.append(f"> Using optimizer_q2 {optimizer_q2} with learning rate {run_config.get('learning_rate'):.6}\n")
    config_message.append(f"> Swapping q1 and q2 every {run_config.get('switch_after_epochs')} epochs\n")
    config_message.append(f"> Cloning q2 from q1 every {run_config.get('clone_after_epochs')} epochs\n")
    config_message.append(f"> Assigning a weight of {run_config.get('priority_weight')} to anchoring state/action pairs\n")
    config_message.append(f"="*114 + "\n")
    # Make sure directory exists for file
    generate_directory_for_file_path(config_fname, print_msg_on_dir_creation=True)
    # Open file and print messages, also to stdout if configured
    # - also print q-model summaries, only to file
    with open(config_fname, "w") as config_file :
        for line in config_message :
            config_file.write(line)
            if not to_stdout : continue
            sys.stdout.write(line)
        config_file.write("\nModel configs:\n\n")
        q1_model.summary(print_fn=lambda x: config_file.write(x + '\n'))
        config_file.write("\n")
        q2_model.summary(print_fn=lambda x: config_file.write(x + '\n'))
        
        
def create_value_estimate_plot(test_states, true_q, target_q, q_model, bs_model, epoch_idx=-1, 
                               show=False, close=False, verbose=False, save="", dpi=100) :
    '''
    Create a plt.Figure instance comparing the estimated, bootstrap, target and true q-values, with one panel 
    per action. Allows for plot to be shown, saved and/or closed using plt interface. Returns the plot figure
    and axis objects so they can continue to be manipulated, but note that the figure can no longer be drawn
    if we have called plt.close(fig).
    Inputs:
      > test_states, np.ndarray of size (num_states,)
        states used to evaluate models
      > true_q, np.array of shape (3*num_states,)
        true q-values in concatenated list of actions = [-1, 0, 1]
      > target_q, np.array of shape (3*num_states,)
        target q-values in concatenated list of actions = [-1, 0, 1]
      > q_model, np.array of shape (3*num_states,)
        estimated q-values in concatenated list of actions = [-1, 0, 1]
      > bs_model, np.array of shape (3*num_states,)
        bootstrap q-values in concatenated list of actions = [-1, 0, 1]
      > epoch_idx, int, default=-1
        if positive then draw a text box displaying how many epochs have been performed
      > show, bool, default=False
        if True then call plt.show(fig)
      > close, bool, default=False
        if True then call plt.close(fig)
      > save, str, default=""
        if string provided then call fig.savefig(save, ...), creating any required subdirectories if needed
    Returns:
      > plt.Figure instance
        Figure object
      > plt.Axes instance
        Left-hand axis object
      > plt.Axes instance
        Middle axis object
      > plt.Axes instance
        Right-hand axis object
    '''
    
    num_test_states = len(test_states)
     
    #  Keep track of how long plotting takes, to help inform how often to call this function    
    start_time = time.time()

    #  Make plot
    fig = plt.figure(figsize=(14, 6))
    fig.set_facecolor("white")
    fig.set_alpha(1)
    
    ax1 = fig.add_subplot(1, 3, 1)
    ax1.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
    ax1.plot(test_states, q_model [:num_test_states], "o-" , c="r"         , ms=5, lw=3, alpha=0.5, label="Estimated $q(s,a)$")
    ax1.plot(test_states, bs_model[:num_test_states], "x-" , c="b"         , ms=5, lw=3, alpha=0.5, label="Bootstrap")
    ax1.plot(test_states, target_q[:num_test_states], "x-" , c="darkorange", ms=5, lw=3, alpha=0.5, label="Target")
    ax1.plot(test_states, true_q  [:num_test_states], ".--", c="gray"      , ms=5, lw=3, alpha=0.5, label="True")
    ax1.grid(True, which='both')
    ax1.set_xlabel("$x$", labelpad=15, fontsize=14)
    
    ax2 = fig.add_subplot(1, 3, 2)
    ax2.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
    ax2.plot(test_states, q_model [num_test_states:2*num_test_states], "o-" , c="r"         , ms=5, lw=3, alpha=0.5, label="Estimated $q(s,a)$")
    ax2.plot(test_states, bs_model[num_test_states:2*num_test_states], "x-" , c="b"         , ms=5, lw=3, alpha=0.5, label="Bootstrap")
    ax2.plot(test_states, target_q[num_test_states:2*num_test_states], "x-" , c="darkorange", ms=5, lw=3, alpha=0.5, label="Target")
    ax2.plot(test_states, true_q  [num_test_states:2*num_test_states], ".--", c="gray"      , ms=5, lw=3, alpha=0.5, label="True")
    ax2.grid(True, which='both')
    ax2.set_xlabel("$x$", labelpad=15, fontsize=14)
    
    ax3 = fig.add_subplot(1, 3, 3)
    ax3.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=12)
    ax3.plot(test_states, q_model [2*num_test_states:3*num_test_states], "o-" , c="r"         , ms=5, lw=3, alpha=0.5, label="Estimated $q(s,a)$")
    ax3.plot(test_states, bs_model[2*num_test_states:3*num_test_states], "x-" , c="b"         , ms=5, lw=3, alpha=0.5, label="Bootstrap")
    ax3.plot(test_states, target_q[2*num_test_states:3*num_test_states], "x-" , c="darkorange", ms=5, lw=3, alpha=0.5, label="Target")
    ax3.plot(test_states, true_q  [2*num_test_states:3*num_test_states], ".--", c="gray"      , ms=5, lw=3, alpha=0.5, label="True")
    ax3.grid(True, which='both')
    ax3.set_xlabel("$x$", labelpad=15, fontsize=14)
    
    #  Find string representing bootstrap greedy policy
    str_bs_greedy_policy = "Bootstrap policy: "
    for s in test_states :
        if is_terminal(s) : 
            str_bs_greedy_policy += "  |"
            continue
        qL, q0, qR = bs_model[s], bs_model[num_test_states+s], bs_model[2*num_test_states+s]
        if   qL > q0 and qL > qR : str_bs_greedy_policy += "  L"
        elif q0 > qL and q0 > qR : str_bs_greedy_policy += "  0"
        elif qR > qL and qR > q0 : str_bs_greedy_policy += "  R"
        else : str_bs_greedy_policy += "  ?"
             
    #  Draw accompanying plot objects
    ax1.legend(loc=(0.7,1.06), ncol=4, fontsize=14, frameon=False)
    ax1.axhline(0, lw=1, c="k", ls="-")
    ax2.axhline(0, lw=1, c="k", ls="-")
    ax3.axhline(0, lw=1, c="k", ls="-")
    ax1.text(0.01, 1.01, f"Action: left" , ha="left", va="bottom", weight="bold", transform=ax1.transAxes, 
             alpha=0.8, fontsize=12, c="k")
    ax2.text(0.01, 1.01, f"Action: stay" , ha="left", va="bottom", weight="bold", transform=ax2.transAxes, 
             alpha=0.8, fontsize=12, c="k")
    ax3.text(0.01, 1.01, f"Action: right", ha="left", va="bottom", weight="bold", transform=ax3.transAxes, 
             alpha=0.8, fontsize=12, c="k")
    ax1.text(0, -0.2, f"{str_bs_greedy_policy}", ha="left", va="top", weight="bold", transform=ax1.transAxes, fontsize=12, c="k")
        
    #  Figure out and set y-axis ranges
    y_min   = np.nanmin([0, np.nanmin(true_q), np.nanmin(target_q), np.nanmin(q_model), np.nanmin(bs_model)])
    y_max   = np.nanmax([0, np.nanmax(true_q), np.nanmax(target_q), np.nanmax(q_model), np.nanmax(bs_model)])
    y_range = y_max - y_min
    y_pad   = 0.1
    y_lim   = [y_min - y_pad*y_range, y_max + y_pad*y_range]
    ax1.set_ylim(y_lim)
    ax2.set_ylim(y_lim)
    ax3.set_ylim(y_lim)
    
    #  Draw text boxes displaying title and num. epochs
    if epoch_idx >= 0 :
        ax1.text(0., 1.08, f"After {epoch_idx} epochs", ha="left", va="bottom", weight="bold", 
                 transform=ax1.transAxes, fontsize=14)
       
    #  Save / show / close
    if len(save) > 0 :
        generate_directory_for_file_path(save, print_msg_on_dir_creation=verbose)
        fig.savefig(save, bbox_inches="tight", dpi=dpi)
    if show :
        plt.show()   # plt.show does not take a figure argument
    if close :
        plt.close(fig)
        
    #  Return figure and axis
    return fig, ax1, ax2, ax3

        
def create_training_curves_plot(loss_record, ref_loss_record, maxQ_record, true_max_Q=np.nan, 
                                show=False, close=False, verbose=False, save="", dpi=100) :
    '''
    Create a plt.Figure instance visualising the training curves. Allows for plot to be shown, saved and/or 
    closed using plt interface. Returns the plot figure and axis objects so they can continue to be 
    manipulated, but note that objects will no longer be in scope if we have called plt.close(fig).
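    Inputs:
      > loss_record, dict with keys "Q1" and "Q2"
        each value is a list of (epoch, loss) pairs for the corresponding model
      > ref_loss_record, dict with keys "Q1" and "Q2"
        each value is a list of (epoch, reference loss) pairs for the corresponding model
      > maxQ_record, dict with keys "Q1" and "Q2"
        each value is a list of (epoch, max |q|) pairs for the corresponding model
      > true_max_Q, float, default=np.nan
        if finite then draw a horizontal line at this value on the max |q| axis
      > verbose, bool, default=False
        if True then print a message whenever a directory is created to save the plot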
      > show, bool, default=False
        if True then call plt.show(fig)
      > close, bool, default=False
        if True then call plt.close(fig)
      > save, str, default=""
        if string provided then call fig.savefig(save, ...), creating any required subdirectories if needed
    Returns:
      > plt.Figure instance
      > plt.Axes instance (axis corresponding to loss curves)
      > plt.Axes instance (axis corresponding to ref_loss curves)
      > plt.Axes instance (axis corresponding to maxQ curves)
    '''
    
    def draw_curve(ax, container, m, c, label) :
        ax.plot([x for x,y in container], [y for x,y in container], m, ms=7, c=c, alpha=1.0, label=label)
        for ((x1,y1), (x2,y2)) in zip(container[:-1], container[1:]) :
            if np.fabs(x2-x1) > 1.5 : continue
            ax.plot([x1, x2], [y1, y2], "-", c=c, lw=2, alpha=0.7)
            
    fig = plt.figure(figsize=(30,15))
    fig.set_facecolor("white")
    fig.set_alpha(1)
    
    ax1 = fig.add_subplot(3, 1, 1)
    ax1.grid(True, which='both')
    ax1.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
    ax1.set_title(r"Mean loss per batch [$(1-\lambda)\cdot$batch + $\lambda\cdot$ref]", fontsize=30)
    ax1.xaxis.set_ticklabels([])
    draw_curve(ax1, loss_record["Q1"], "o", "r", "$q_1$")
    draw_curve(ax1, loss_record["Q2"], "x", "b", "$q_2$")
    ax1.set_yscale("log")
    ax1.legend(loc=(0,1.02), fontsize=30, ncol=3, title_fontsize=30)
    
    ax2 = fig.add_subplot(3, 1, 2)
    ax2.grid(True, which='both')
    ax2.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
    ax2.set_title(r"Mean loss per batch [ref only]", fontsize=30)
    ax2.xaxis.set_ticklabels([])
    draw_curve(ax2, ref_loss_record["Q1"], "o", "r", "$q_1$")
    draw_curve(ax2, ref_loss_record["Q2"], "x", "b", "$q_2$")
    ax2.set_yscale("log")
    
    ax3 = fig.add_subplot(3, 1, 3)
    ax3.grid(True, which='both')
    ax3.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
    ax3.set_title(r"Max $|q(s,a)|$ over all batches", fontsize=30)
    ax3.set_xlabel(r"Epoch", labelpad=15, fontsize=30)
    draw_curve(ax3, maxQ_record["Q1"], "o", "r", "$q_1$")
    draw_curve(ax3, maxQ_record["Q2"], "x", "b", "$q_2$")
    ax3.axhline(0, ls="--", lw=2, c="gray")
    if np.isfinite(true_max_Q) :
        ax3.axhline(true_max_Q, ls="--", lw=2, c="gray")
        ax3.text(0, true_max_Q, "True maximum", fontsize=20, ha="left", va="top", c="k")
    
    fig.subplots_adjust(hspace=0.2)
    
    if len(save) > 0 :
        generate_directory_for_file_path(save, print_msg_on_dir_creation=verbose)
        fig.savefig(save, bbox_inches="tight", dpi=dpi)
    if show :
        plt.show()   # plt.show does not take a figure argument
    if close :
        plt.close(fig)
        
    return fig, ax1, ax2, ax3
    

In [5]:
###
###  Method for creating action-value model
###

def create_action_value_model(name=None) :
    '''
    Create a network for the action-value model.
    Inputs:
      > name, str, default=None
        model name, if None then keras default is used
    Returns:
      > keras Model: compiled with MSE loss and the default SGD optimizer, but intended to be trained using a custom loop
    '''
    input_layer_x = Input ((1,))
    input_layer_a = Input ((1,))
    next_layer_x  = Rescaling(2./x_range, offset=-(x_max+x_min)/x_range)(input_layer_x)
    next_layer_a  = Rescaling(2./a_range, offset=-(a_max+a_min)/a_range)(input_layer_a)
    next_layer_x  = Dense(25, activation="relu")(next_layer_x)
    next_layer_a  = Dense(25, activation="relu")(next_layer_a)
    next_layer    = Concatenate()([next_layer_x, next_layer_a])
    next_layer    = Dense(100, activation="relu")(next_layer)
    next_layer    = Dense(100, activation="relu")(next_layer)
    output_layer  = Dense(1, activation="linear")(next_layer)
    model         = Model([input_layer_x, input_layer_a], output_layer, name=name)
    model.compile(loss="mse", optimizer="sgd")
    return model
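
The two Rescaling layers apply the affine map `input * scale + offset`, chosen so that the state and action inputs each land on [-1, 1]. The sketch below repeats that arithmetic in plain Python as a sanity check, using the limits printed by the config cell; `rescale` is a hypothetical helper.

```python
def rescale(value, v_min, v_max):
    """Same affine map as Rescaling(2./range, offset=-(v_max+v_min)/range)."""
    v_range = v_max - v_min
    return value * (2.0 / v_range) - (v_max + v_min) / v_range

x_min, x_max = 0, 19   # board limits printed by the config cell
a_min, a_max = -1, 1   # action limits printed by the config cell

print([round(rescale(x, x_min, x_max), 3) for x in (x_min, 16, x_max)])
print([rescale(a, a_min, a_max) for a in (-1, 0, 1)])
```

With symmetric action limits the action map is the identity, so only the state input is actually changed by the rescaling.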
    

In [6]:
###
###  Identify priority state/action pairs
###

priority_states, priority_actions, priority_returns = [], [], []
num_priority_datapoints = 0
for action in action_list :
    x_initial = x_end - action
    if is_terminal(x_initial) : continue
    if is_out_of_bounds(x_initial) : continue
    reward, _ = perform_action(x_initial, action)
    num_priority_datapoints += 1
    priority_states .append(x_initial)
    priority_actions.append(action)
    priority_returns.append(reward)
    
priority_states  = np.array(priority_states ).reshape((num_priority_datapoints, 1))
priority_actions = np.array(priority_actions).reshape((num_priority_datapoints, 1))
priority_returns = np.array(priority_returns).reshape((num_priority_datapoints, 1))

print(f"Found {num_priority_datapoints} priority state-action pairs with returns:  {'  '.join([f'{r:.2f}' for r in priority_returns.flatten()])}")


Found 2 priority state-action pairs with returns:  -1.00  -1.00
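
These two pairs are the ones whose return is known exactly, since the action steps directly into the terminal state. The training-curve title used earlier, (1-λ)·batch + λ·ref, suggests they enter the loss as a weighted reference term. The sketch below shows that combination under the assumption that λ is the configured priority_weight and "ref" is the MSE over the priority pairs; the training loop itself is not shown in this section, so treat `combined_loss` and the example values as hypothetical.

```python
def mse(predictions, targets):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)

def combined_loss(batch_pred, batch_target, ref_pred, ref_target, priority_weight):
    """(1 - lambda) * batch MSE + lambda * reference MSE, with lambda = priority_weight."""
    return ((1.0 - priority_weight) * mse(batch_pred, batch_target)
            + priority_weight * mse(ref_pred, ref_target))

# Hypothetical predictions: an ordinary batch, plus the two priority pairs (true return -1)
batch_loss_inputs = ([-5.0, -7.0], [-6.0, -9.0])
ref_loss_inputs   = ([-3.0, -1.0], [-1.0, -1.0])

print(combined_loss(*batch_loss_inputs, *ref_loss_inputs, priority_weight=0.25))
```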


In [7]:
###
###  Experiment initialisation functions
###

def initialise_keras_objects(bootstrap_method, learning_rate, tag="no_tag") :
    loss_fcn     = tf.keras.losses.MeanSquaredError()
    q1_model     = create_action_value_model(name=f"action_value_model_1_{tag}")
    optimizer_q1 = tf.keras.optimizers.SGD(learning_rate=learning_rate)
    if bootstrap_method == "self" :
        q2_model     = q1_model
        optimizer_q2 = None
    elif bootstrap_method == "clone" :
        q2_model     = create_action_value_model(name=f"action_value_model_1_{tag}_clone")
        optimizer_q2 = None
    elif bootstrap_method == "other" :
        q2_model     = create_action_value_model(name=f"action_value_model_{tag}_2")
        optimizer_q2 = tf.keras.optimizers.SGD(learning_rate=learning_rate)
    else :
        raise NotImplementedError
    return loss_fcn, q1_model, q2_model, optimizer_q1, optimizer_q2
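
To make the "clone" option concrete, here is a tabular sketch of bootstrapping from a frozen copy: targets are reward + gamma * max over actions of the frozen q, and the frozen copy is refreshed periodically. This is an assumption-laden toy, not the notebook's training loop (which is not shown in this section): it uses a small hypothetical board, a lookup table instead of a network, and a learning rate of 1.

```python
X_MIN, X_MAX, X_END = 0, 5, 4          # small hypothetical board
ACTIONS = (-1, 0, 1)
GAMMA, REWARD_PER_TURN, LAMBDA_B = 1.0, -1.0, -1.0

def step(x, action):
    """Same rules as perform_action with lambda_dx = 0."""
    x_next, reward = x + action, REWARD_PER_TURN
    if x_next < X_MIN or x_next > X_MAX:
        x_next, reward = x, reward + LAMBDA_B
    return reward, x_next

def bootstrap_target(x, action, q_frozen):
    """One-step target: reward + gamma * max_a' q_frozen(x', a'), with value 0 at the terminal state."""
    reward, x_next = step(x, action)
    if x_next == X_END:
        return reward
    return reward + GAMMA * max(q_frozen[(x_next, a)] for a in ACTIONS)

def train(num_sweeps, clone_after_sweeps=1):
    pairs = [(x, a) for x in range(X_MIN, X_MAX + 1) if x != X_END for a in ACTIONS]
    q = {pair: 0.0 for pair in pairs}
    q_frozen = dict(q)
    for sweep in range(num_sweeps):
        if sweep % clone_after_sweeps == 0:
            q_frozen = dict(q)                            # refresh the frozen copy
        for x, a in pairs:
            q[(x, a)] = bootstrap_target(x, a, q_frozen)  # learning rate 1 for clarity
    return q

q = train(num_sweeps=20)
print(q[(0, 1)], q[(3, 1)], q[(5, -1)], q[(5, 1)])
```

In this tabular setting the frozen-target update is plain value iteration, so none of the smoothing or lever effects from the notes appear; those come from function approximation.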


In [8]:
###
###  Get global constants for experiments
###

# State-action pairs from which to form batches

all_state_action_pairs = get_state_action_pairs()
num_state_action_pairs = len(all_state_action_pairs)

# True value function, also 'test' objects for evaluating keras models for comparison later

test_states_flat = np.arange(x_min, x_max+1)
num_test_states  = len(test_states_flat)
test_states      = test_states_flat.reshape((num_test_states,1))
test_states      = np.concatenate([test_states, test_states, test_states])
test_actions_L   = np.full(fill_value=-1, shape=(num_test_states,1))
test_actions_0   = np.full(fill_value=0 , shape=(num_test_states,1))
test_actions_R   = np.full(fill_value=1 , shape=(num_test_states,1))
test_actions     = np.concatenate([test_actions_L, test_actions_0, test_actions_R])
true_q           = get_true_q(test_states.flatten(), test_actions.flatten())
true_max_abs_q   = np.nanmax(np.fabs(true_q))


In [9]:

def run_experiment(run_config, run_idx=1, verbose=True, epoch_tracker=None) :
            #  Extract configuration variables from dict
    max_mse_true               = run_config.get("max_mse_true", -np.inf)
    max_epochs                 = run_config.get("max_epochs")
    batch_size                 = run_config.get("batch_size")
    learning_rate              = run_config.get("learning_rate")
    test_after_epochs          = run_config.get("test_after_epochs")
    plot_estimate_after_epochs = run_config.get("plot_estimate_after_epochs")
    plot_monitors_after_epochs = run_config.get("plot_monitors_after_epochs")
    save_objects_after_epochs  = run_config.get("save_objects_after_epochs")
    switch_after_epochs        = run_config.get("switch_after_epochs")
    clone_after_epochs         = run_config.get("clone_after_epochs")
    priority_weight            = run_config.get("priority_weight")
    bootstrap_method           = run_config.get("bootstrap_method")
    num_step_returns           = run_config.get("num_step_returns")
    run_tag                    = run_config.get("run_tag")
            #  Make sure bootstrap method is valid
    if bootstrap_method not in ["clone", "self", "other"] :
        raise NotImplementedError(f"Bootstrap method {bootstrap_method} not implemented")
            #  Resolve top directory based on configured run tag and run index
    top_directory = f"figures/Helicopter_NB0/{run_tag}/experiment_{run_idx}"
            #  Initialise keras objects
    loss_fcn, q1_model, q2_model, optimizer_q1, optimizer_q2 = initialise_keras_objects(bootstrap_method, learning_rate, tag=f"{run_tag}_{run_idx}")
            #  Print config to file
    create_config(f"{top_directory}/config.txt", run_config, q1_model, q2_model, optimizer_q1, optimizer_q2, to_stdout=verbose)
            #  Initialise monitors
    model_key_q1, model_key_q2 = "Q1", "Q2"  # used to keep track of which model is being trained each epoch
    loss_record     = {model_key_q1:[], model_key_q2:[]}
    ref_loss_record = {model_key_q1:[], model_key_q2:[]}
    maxQ_record     = {model_key_q1:[], model_key_q2:[]}
    epochs_record   = [0]
    q1_record       = [q1_model.predict([test_states, test_actions]).flatten()]
    q2_record       = [q2_model.predict([test_states, test_actions]).flatten()]
    mse_true        = get_mse(q1_record[-1], true_q)
    mse_true_record = []
            #  Create methods to help monitoring
    def test_and_record(epoch_idx) :
        epochs_record.append(epoch_idx)
        q1_record      .append(q1_model.predict([test_states, test_actions]).flatten())
        q2_record      .append(q2_model.predict([test_states, test_actions]).flatten())
        mse_true_record.append(get_mse(q1_record[-1], true_q))
        return mse_true_record[-1]
    def plot_value_functions(epoch_idx) :
        create_value_estimate_plot(test_states_flat, true_q, target_q, q1_record[-1], q2_record[-1], 
                                   epoch_idx=epoch_idx, verbose=verbose, show=False, close=True, 
                                   save=f"{top_directory}/value_estimates_epoch{epoch_idx}.png")
    def plot_training_curves() :
        create_training_curves_plot(loss_record, ref_loss_record, maxQ_record, true_max_Q=true_max_abs_q,
                                    verbose=verbose, show=False, close=True, 
                                    save=f"{top_directory}/training_curves.pdf")
    def save_objects() :
        to_save = {"run_config":run_config, "run_idx":run_idx, "epoch_idx":epoch_idx, "loss_record":loss_record,
                   "ref_loss_record":ref_loss_record, "maxQ_record":maxQ_record, "epochs_record":epochs_record,
                   "q1_record":q1_record, "q2_record":q2_record}
        fname   = f"{top_directory}/saved_objects.pickle"
        generate_directory_for_file_path(fname, print_msg_on_dir_creation=verbose)
        with open(fname, "wb") as f :
            pickle.dump(to_save, f)
        tf_log_level = tf.get_logger().level
        tf.get_logger().setLevel('WARNING')
        q1_model.save(f"{top_directory}/q1_model")
        q2_model.save(f"{top_directory}/q2_model")
        tf.get_logger().setLevel(tf_log_level)
    
            #  Start training
    epoch_idx, start_time  = 0, time.time()
    state_action_pairs     = all_state_action_pairs.copy()
    state_action_pair_idcs = np.arange(num_state_action_pairs)
    target_q               = np.ones_like(true_q)*np.nan
    while (epoch_idx < max_epochs or max_epochs < 0) and mse_true > max_mse_true :

        # Determine whether to test and save value function estimates
        if test_after_epochs > 0 and epoch_idx % test_after_epochs == 0 :
            mse_true = test_and_record(epoch_idx)

        # Determine whether to plot value function estimates
        if plot_estimate_after_epochs > 0 and epoch_idx % plot_estimate_after_epochs == 0 :
            if epochs_record[-1] != epoch_idx :
                print(f"WARNING: plot value function is out of date at epoch index ({epochs_record[-1]} < {epoch_idx})")
            plot_value_functions(epoch_idx)
            
        # Determine whether to plot training curves
        if plot_monitors_after_epochs > 0 and epoch_idx > 0 and epoch_idx % plot_monitors_after_epochs == 0 :
            plot_training_curves()

        # Determine whether to save objects
        if save_objects_after_epochs > 0 and epoch_idx % save_objects_after_epochs == 0 :
            save_objects()
            
        # Determine whether to switch q1 and q2
        if bootstrap_method == "other" and switch_after_epochs > 0 and epoch_idx > 0 and epoch_idx % switch_after_epochs == 0 :
            model_key_q1, model_key_q2 = model_key_q2, model_key_q1
            q1_model, q2_model         = q2_model, q1_model
            optimizer_q1, optimizer_q2 = optimizer_q2, optimizer_q1

        # Determine whether to copy q1 to q2
        if bootstrap_method == "clone" and clone_after_epochs > 0 and epoch_idx % clone_after_epochs == 0 :
            q2_model.set_weights(q1_model.get_weights()) 

        if verbose :
            sys.stdout.write(f"Epoch {epoch_idx+1} / {max_epochs}  [t={time.time()-start_time:.2f}s]")

        # Shuffle states and loop over batches, performing one gradient update per batch
        # - this has lower variance than applying one gradient update per sample, and also allows parallelisation
        np.random.shuffle(state_action_pair_idcs)
        state_action_pairs = state_action_pairs[state_action_pair_idcs]
        epoch_losses, ref_losses, max_abs_q_values = [], [], []
        for batch_idx in range(math.ceil(num_state_action_pairs/batch_size)) :
            # Resolve sample indices to be used for this batch update
            batch_idx_low, batch_idx_high = batch_idx*batch_size, min((batch_idx+1)*batch_size, num_state_action_pairs)
            actual_batch_size = batch_idx_high - batch_idx_low
            if actual_batch_size == 0 : continue

            # Update the current epoch message to keep track of batch progress 
            if verbose :
                sys.stdout.write(f"\rEpoch {epoch_idx+1} / {max_epochs} batch indices ({batch_idx_low}, {batch_idx_high}) / {num_state_action_pairs}  [t={time.time()-start_time:.2f}s]")

            # Get the batch of state/action pairs
            batch_state_action_pairs = state_action_pairs[batch_idx_low:batch_idx_high]

            # For each state/action pair, apply the action to get the immediate reward, and also find the
            # greedy action in the next state from which to calculate bootstraps
            batch_emp_returns, batch_states_p, batch_actions_p, batch_state_p_is_terminal = [], [], [], []
            for state, action in batch_state_action_pairs :
                state_p, action_p, emp_return, step_discount = state, action, 0., 1.
                for step_idx in range(num_step_returns) :
                    if not is_terminal(state_p) :
                        step_reward, state_p = perform_action(state_p, action_p)
                        action_p             = get_greedy_action(state_p, q1_model)
                        emp_return          += step_discount*step_reward
                    step_discount *= gamma
                batch_emp_returns        .append(emp_return)
                batch_states_p           .append([state_p])
                batch_actions_p          .append([action_p])
                batch_state_p_is_terminal.append(bool(is_terminal(state_p)))
            batch_emp_returns, batch_states_p            = np.array(batch_emp_returns), np.array(batch_states_p)
            batch_actions_p  , batch_state_p_is_terminal = np.array(batch_actions_p  ), np.array(batch_state_p_is_terminal)

            # Get inputs and ensure correct shapes
            batch_states  = batch_state_action_pairs[:,0].reshape((actual_batch_size,1))
            batch_actions = batch_state_action_pairs[:,1].reshape((actual_batch_size,1))

            # Calculate obs_returns using the observed n-step rewards plus gamma^n * bootstrap values
            batch_bootstrap_values = q2_model.predict([batch_states_p, batch_actions_p]).flatten()
            batch_bootstrap_values = np.array([0. if is_term else q for is_term,q in zip(batch_state_p_is_terminal,batch_bootstrap_values)])        
            batch_obs_returns      = batch_emp_returns + step_discount * batch_bootstrap_values
            
            # Make sure output shapes are correctly formatted                    
            batch_obs_returns = batch_obs_returns.reshape((actual_batch_size,1))

            # Store target q for later plot (tricky to resolve correct index since we shuffled state-action pairs)
            for sample_idx, (state, action) in enumerate(batch_state_action_pairs) :
                test_idx = num_states * (action+1) + state
                target_q[test_idx] = batch_obs_returns[sample_idx]

            # Calculate weights to apply to regular batch and priority samples
            batch_weights    = np.full(shape=(actual_batch_size      ,), fill_value=1./actual_batch_size      )
            priority_weights = np.full(shape=(num_priority_datapoints,), fill_value=1./num_priority_datapoints)

            # Concatenate regular batch and priority samples
            train_weights, train_states, train_actions, train_obs_returns = batch_weights, batch_states, batch_actions, batch_obs_returns
            if priority_weight > 0 and priority_weight <= 1 :
                batch_weights    *= (1.-priority_weight)
                priority_weights *= priority_weight
                train_weights     = np.concatenate([batch_weights    , priority_weights])
                train_states      = np.concatenate([batch_states     , priority_states ])
                train_actions     = np.concatenate([batch_actions    , priority_actions])
                train_obs_returns = np.concatenate([batch_obs_returns, priority_returns])

            # Train weights must be normalised to num samples instead of 1 to recover the expected loss value
            train_weights = train_weights * len(train_weights) / train_weights.sum()

            # ref_loss is the loss over only the priority state/action pairs, before gradient updates
            ref_returns = q1_model([priority_states, priority_actions], training=False)
            ref_loss    = loss_fcn(priority_returns.reshape((num_priority_datapoints,1)), ref_returns.numpy())
            ref_losses.append(ref_loss)

            # Apply gradient updates and store monitor values
            with tf.GradientTape() as tape:
                train_pred_returns = q1_model([train_states, train_actions], training=True)
                loss_value         = loss_fcn(train_obs_returns, train_pred_returns, sample_weight=train_weights)
            # Compute and apply gradients outside the tape context so these ops are not recorded
            grads = tape.gradient(loss_value, q1_model.trainable_weights)
            optimizer_q1.apply_gradients(zip(grads, q1_model.trainable_weights))
            epoch_losses.append(loss_value.numpy())
            max_abs_q_values.append(np.fabs(train_pred_returns.numpy()).max())

        # Print epoch summary and end stdout line to keep this on screen
        epoch_mean_loss, epoch_mean_ref_loss, epoch_maxQ = np.mean(epoch_losses), np.mean(ref_losses), np.max(max_abs_q_values)
        if verbose :
            sys.stdout.write(f"\rEpoch {epoch_idx+1} / {max_epochs}  [t={time.time()-start_time:.2f}s]  <loss = {epoch_mean_loss:.5f}, ref_loss = {epoch_mean_ref_loss:.5f}, max_Q = {epoch_maxQ:.1f},  mse_true = {mse_true:.3f}>".ljust(100)+"\n")
        loss_record    [model_key_q1].append((epoch_idx, epoch_mean_loss    ))
        ref_loss_record[model_key_q1].append((epoch_idx, epoch_mean_ref_loss))
        maxQ_record    [model_key_q1].append((epoch_idx, epoch_maxQ         ))

        # Manually increment the epoch index, updating the shared tracker if one was provided
        if epoch_tracker is not None :
            with epoch_tracker.get_lock():
                epoch_tracker.value += 1
        epoch_idx += 1
        
    # Save final results
    test_and_record(epoch_idx)
    plot_value_functions(epoch_idx)
    plot_training_curves()
    save_objects()
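
The target construction and sample weighting used inside the batch loop above can be isolated as a small NumPy sketch. The helper names `n_step_target` and `mix_weights` are hypothetical and not part of the notebook, and the sketch assumes the per-sample rewards have already been rolled out (padded with zeros after termination). It only illustrates the arithmetic: the discounted empirical return plus gamma^n times the bootstrap value, zeroed for terminal states, and the rescaling of the mixed batch/priority weights so that they sum to the number of samples.

```python
import numpy as np

def n_step_target(rewards, bootstrap_values, is_terminal, gamma):
    """Discounted n-step return plus gamma**n * bootstrap (0 at terminal states).

    rewards          : (batch, n) step rewards, padded with 0 after termination
    bootstrap_values : (batch,) Q(s_n, a_n) taken from the bootstrap model
    is_terminal      : (batch,) booleans, True if s_n is terminal
    """
    rewards = np.asarray(rewards, dtype=float)
    n = rewards.shape[1]
    discounts = gamma ** np.arange(n)            # 1, gamma, gamma^2, ...
    emp_returns = rewards @ discounts            # sum_k gamma^k * r_k
    bootstrap = np.where(is_terminal, 0., bootstrap_values)
    return emp_returns + gamma**n * bootstrap

def mix_weights(batch_size, num_priority, priority_weight):
    """Give priority samples a total share of priority_weight, then rescale to sum to N."""
    batch_w = np.full(batch_size, (1. - priority_weight) / batch_size)
    prio_w  = np.full(num_priority, priority_weight / num_priority)
    weights = np.concatenate([batch_w, prio_w])
    return weights * len(weights) / weights.sum()

# Two samples with 3-step returns; the second reaches a terminal state after one step
targets = n_step_target(rewards=[[-1., -1., -1.], [-1., 0., 0.]],
                        bootstrap_values=np.array([-5., -2.]),
                        is_terminal=np.array([False, True]),
                        gamma=0.9)

# 10 regular samples plus 2 priority samples sharing 30% of the total weight
weights = mix_weights(batch_size=10, num_priority=2, priority_weight=0.3)
```

With only two priority samples carrying 30% of the weight, each one counts for roughly twice as much as a regular sample, which is the intended anchoring effect.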
        

In [10]:
###
###  Methods for running experiments in threads
###


def kill_threads(threads=None, verbose=False) :
    '''
    Flags every thread in threads as killed (defaults to all registered threads)
    '''
    if threads is None :
        global all_threads
        threads = all_threads
    for thread in threads :
        if not hasattr(thread, "kill") :
            continue
        thread.kill(verbose=verbose)
        
        
class BaseThread(Thread) :
    
    def __init__(self):
        Thread.__init__(self)
        self.killed = False
        self.lock   = Lock()
        global all_threads
        all_threads.append(self)
        
    def kill(self, killed=True, verbose=False) :
        if not self.killed and killed and verbose :
            self.info(f"Killing thread: {self.name}\n")
        self.killed = killed
            
    def info(self, message) :
        self.lock.acquire()
        sys.stdout.write(f"{message}")
        sys.stdout.flush()
        self.lock.release()
    
    
class WorkerThread(BaseThread):
    
    def __init__(self, num_processes, run_config=None):
        BaseThread.__init__(self)
        self.num_processes = num_processes
        self.run_config    = {} if run_config is None else run_config   # avoid a shared mutable default
        
    def run(self):
        self.kill(False)
        self.is_running = True
        self.processes  = []
        for proc_idx in range(self.num_processes) :
            run_idx         = proc_idx + 1
            epoch_tracker   = Value('i', -1)
            p               = Process(target=run_experiment, args=(self.run_config, run_idx, False, epoch_tracker))
            p.epoch_tracker = epoch_tracker
            p.start()
            self.processes.append(p)
        for p in self.processes :
            p.join()
        self.is_running = False
       
    
class MonitorThread(BaseThread):
    
    def __init__(self, worker, interval=1):
        BaseThread.__init__(self)
        self.worker   = worker
        self.interval = interval
        
    def run(self):
        self.kill(False)
        while self.worker.is_running and not self.killed :
            epochs = ', '.join(str(p.epoch_tracker.value) for p in self.worker.processes)
            self.info(f"\r{len(self.worker.processes)} processes running at epochs [{epochs}] / {self.worker.run_config.get('max_epochs',np.nan)}")
            time.sleep(self.interval)
        self.info("\n")
        
    

In [25]:
## Configure

run_config_baseline = dict(
    max_mse_true               = 0.05        ,
    max_epochs                 = 400         ,
    batch_size                 = 10          ,
    learning_rate              = 1e-2        ,
    test_after_epochs          = 1           ,
    plot_estimate_after_epochs = 5           ,
    plot_monitors_after_epochs = 10          ,
    save_objects_after_epochs  = 10          ,
    switch_after_epochs        = -1          ,
    clone_after_epochs         = 5           ,
    priority_weight            = 0.3         ,
    bootstrap_method           = "clone"     ,       # ["clone", "self", "other"]
    num_step_returns           = 1           ,
    run_tag                    = "baseline"
)

run_config_multistep = dict(
    max_mse_true               = 0.05        ,
    max_epochs                 = 400         ,
    batch_size                 = 10          ,
    learning_rate              = 1e-2        ,
    test_after_epochs          = 1           ,
    plot_estimate_after_epochs = 5           ,
    plot_monitors_after_epochs = 10          ,
    save_objects_after_epochs  = 10          ,
    switch_after_epochs        = -1          ,
    clone_after_epochs         = 5           ,
    priority_weight            = 0.3         ,
    bootstrap_method           = "clone"     ,       # ["clone", "self", "other"]
    num_step_returns           = 3           ,
    run_tag                    = "multistep"
)

run_config_no_priority = dict(
    max_mse_true               = 0.05        ,
    max_epochs                 = 400         ,
    batch_size                 = 10          ,
    learning_rate              = 1e-2        ,
    test_after_epochs          = 1           ,
    plot_estimate_after_epochs = 5           ,
    plot_monitors_after_epochs = 10          ,
    save_objects_after_epochs  = 10          ,
    switch_after_epochs        = -1          ,
    clone_after_epochs         = 5           ,
    priority_weight            = -1          ,
    bootstrap_method           = "clone"     ,       # ["clone", "self", "other"]
    num_step_returns           = 1           ,
    run_tag                    = "no_priority"
)

run_config_self_bootstrap = dict(
    max_mse_true               = 0.05        ,
    max_epochs                 = 400         ,
    batch_size                 = 10          ,
    learning_rate              = 1e-2        ,
    test_after_epochs          = 1           ,
    plot_estimate_after_epochs = 5           ,
    plot_monitors_after_epochs = 10          ,
    save_objects_after_epochs  = 10          ,
    switch_after_epochs        = -1          ,
    clone_after_epochs         = -1          ,
    priority_weight            = 0.3         ,
    bootstrap_method           = "self"      ,       # ["clone", "self", "other"]
    num_step_returns           = 1           ,
    run_tag                    = "self_bootstrap"
)

run_config_big_batch = dict(
    max_mse_true               = 0.05        ,
    max_epochs                 = 400         ,
    batch_size                 = 999         ,
    learning_rate              = 1e-2        ,
    test_after_epochs          = 1           ,
    plot_estimate_after_epochs = 5           ,
    plot_monitors_after_epochs = 10          ,
    save_objects_after_epochs  = 10          ,
    switch_after_epochs        = -1          ,
    clone_after_epochs         = 5           ,
    priority_weight            = 0.3         ,
    bootstrap_method           = "clone"     ,       # ["clone", "self", "other"]
    num_step_returns           = 1           ,
    run_tag                    = "big_batch"
)

run_config_standard_Q = dict(
    max_mse_true               = 0.05        ,
    max_epochs                 = 150         ,
    batch_size                 = 1           ,
    learning_rate              = 1e-2        ,
    test_after_epochs          = 1           ,
    plot_estimate_after_epochs = 1           ,
    plot_monitors_after_epochs = 10          ,
    save_objects_after_epochs  = 10          ,
    switch_after_epochs        = -1          ,
    clone_after_epochs         = 1           ,
    priority_weight            = 0.          ,
    bootstrap_method           = "clone"     ,       # ["clone", "self", "other"]
    num_step_returns           = 1           ,
    run_tag                    = "standard_Q"
)

run_config_clone_20 = dict(
    max_mse_true               = 0.05        ,
    max_epochs                 = 400         ,
    batch_size                 = 10          ,
    learning_rate              = 1e-2        ,
    test_after_epochs          = 1           ,
    plot_estimate_after_epochs = 5           ,
    plot_monitors_after_epochs = 10          ,
    save_objects_after_epochs  = 10          ,
    switch_after_epochs        = -1          ,
    clone_after_epochs         = 20          ,
    priority_weight            = 0.3         ,
    bootstrap_method           = "clone"     ,       # ["clone", "self", "other"]
    num_step_returns           = 1           ,
    run_tag                    = "clone_20"
)


In [None]:

num_runs = 3

all_run_configs = [run_config_baseline, run_config_multistep, run_config_no_priority, run_config_self_bootstrap,
                   run_config_big_batch, run_config_standard_Q, run_config_clone_20]

for config in all_run_configs :
    for run_idx in range(num_runs) :
        run_experiment(config, run_idx+1, verbose=True)
    

In [23]:

num_runs = 3

for run_idx in range(num_runs) :
    run_experiment(run_config_standard_Q, run_idx+1, verbose=True)

Directory figures/Helicopter_NB0/standard_Q/experiment_1 created
Environment config:
> distance_to_end: 13
> game_board_pad_size: 3
> horizontal_size: 20
> reward_per_turn: -1.0
> lambda_dx: 0.0
> lambda_b: -1.0
> gamma: 1.0
> x_min: 0
> x_max: 19
> x_range: 19
> x_start: 3
> x_end: 16
> action_list: [-1  0  1]
Training config:
> Stop training when mse_true exceeds 0.05
> Using 1 step empirical returns
> Using bootstrap method: clone
> Using epochs of length None
> Updating gradient every batch of size 1
> Using optimizer_q1 <keras.optimizer_v2.gradient_descent.SGD object at 0x2bface8e0> with learning rate 0.01
> Using optimizer_q2 None with learning rate 0.01
> Swapping q1 and q2 every -1 epochs
> Cloning q2 from q1 every 1 epochs
> Assigning a weight of 0.0 to anchoring state/action pairs


  y_min   = np.nanmin([0, np.nanmin(true_q), np.nanmin(target_q), np.nanmin(q_model), np.nanmin(bs_model)])
  y_max   = np.nanmax([0, np.nanmax(true_q), np.nanmax(target_q), np.nanmax(q_model), np.nanmax(bs_model)])


Epoch 1 / 150  [t=5.40s]  <loss = 0.22965, ref_loss = 0.21672, max_Q = 1.1,  mse_true = 92.668>    
Epoch 2 / 150  [t=10.43s]  <loss = 0.19632, ref_loss = 0.83732, max_Q = 2.2,  mse_true = 76.534>   
Epoch 3 / 150  [t=14.66s]  <loss = 0.25019, ref_loss = 3.04245, max_Q = 3.1,  mse_true = 64.762>   
Epoch 4 / 150  [t=18.75s]  <loss = 0.39396, ref_loss = 7.95363, max_Q = 4.1,  mse_true = 53.640>   
Epoch 5 / 150  [t=22.99s]  <loss = 0.48947, ref_loss = 7.80177, max_Q = 4.4,  mse_true = 50.882>   
Epoch 6 / 150  [t=26.94s]  <loss = 0.77316, ref_loss = 11.40518, max_Q = 5.5,  mse_true = 41.840>  
Epoch 7 / 150  [t=31.15s]  <loss = 1.02149, ref_loss = 18.89093, max_Q = 6.7,  mse_true = 34.620>  
Epoch 8 / 150  [t=35.34s]  <loss = 1.05905, ref_loss = 16.29185, max_Q = 7.4,  mse_true = 28.567>  
Epoch 9 / 150  [t=39.38s]  <loss = 1.95633, ref_loss = 22.81470, max_Q = 9.6,  mse_true = 23.948>  
Epoch 10 / 150  [t=43.60s]  <loss = 2.17651, ref_loss = 31.20228, max_Q = 10.1,  mse_true = 20.863>


Epoch 82 / 150  [t=365.89s]  <loss = 13.99043, ref_loss = 295.71210, max_Q = 21.6,  mse_true = 111.959>
Epoch 83 / 150  [t=369.79s]  <loss = 13.28503, ref_loss = 275.31784, max_Q = 23.8,  mse_true = 108.034>
Epoch 84 / 150  [t=374.31s]  <loss = 15.09311, ref_loss = 315.02991, max_Q = 23.1,  mse_true = 122.039>
Epoch 85 / 150  [t=378.37s]  <loss = 14.14397, ref_loss = 245.77353, max_Q = 29.7,  mse_true = 93.786>
Epoch 86 / 150  [t=382.27s]  <loss = 15.57577, ref_loss = 242.14896, max_Q = 28.3,  mse_true = 82.935>
Epoch 87 / 150  [t=386.36s]  <loss = 13.37941, ref_loss = 288.66061, max_Q = 26.1,  mse_true = 110.531>
Epoch 88 / 150  [t=391.14s]  <loss = 16.77017, ref_loss = 296.79550, max_Q = 29.0,  mse_true = 121.984>
Epoch 89 / 150  [t=395.36s]  <loss = 23.88407, ref_loss = 280.62817, max_Q = 33.7,  mse_true = 126.434>
Epoch 90 / 150  [t=399.45s]  <loss = 19.51276, ref_loss = 408.51788, max_Q = 22.5,  mse_true = 175.838>
Epoch 91 / 150  [t=405.39s]  <loss = 17.40539, ref_loss = 452.3045



Epoch 1 / 150  [t=5.92s]  <loss = 0.22863, ref_loss = 0.21203, max_Q = 1.1,  mse_true = 92.668>    
Epoch 2 / 150  [t=9.80s]  <loss = 0.19426, ref_loss = 0.74614, max_Q = 2.2,  mse_true = 76.762>    
Epoch 3 / 150  [t=13.73s]  <loss = 0.27376, ref_loss = 3.05781, max_Q = 3.0,  mse_true = 64.406>   
Epoch 4 / 150  [t=17.66s]  <loss = 0.44308, ref_loss = 7.01379, max_Q = 4.2,  mse_true = 54.054>   
Epoch 5 / 150  [t=21.59s]  <loss = 0.72513, ref_loss = 11.57596, max_Q = 5.3,  mse_true = 45.259>  
Epoch 6 / 150  [t=25.45s]  <loss = 0.62002, ref_loss = 9.43670, max_Q = 5.6,  mse_true = 42.263>   
Epoch 7 / 150  [t=30.20s]  <loss = 1.08642, ref_loss = 16.96079, max_Q = 6.8,  mse_true = 33.663>  
Epoch 8 / 150  [t=34.15s]  <loss = 1.78150, ref_loss = 23.89141, max_Q = 8.5,  mse_true = 25.185>  
Epoch 9 / 150  [t=38.08s]  <loss = 2.42103, ref_loss = 35.08368, max_Q = 9.8,  mse_true = 19.764>  
Epoch 10 / 150  [t=42.01s]  <loss = 2.70148, ref_loss = 43.45083, max_Q = 11.7,  mse_true = 14.858>


Epoch 81 / 150  [t=353.20s]  <loss = 22.22629, ref_loss = 558.59943, max_Q = 25.4,  mse_true = 274.542>
Epoch 82 / 150  [t=357.19s]  <loss = 22.07330, ref_loss = 568.92572, max_Q = 25.8,  mse_true = 281.066>
Epoch 83 / 150  [t=361.28s]  <loss = 22.76839, ref_loss = 588.19220, max_Q = 25.9,  mse_true = 287.427>
Epoch 84 / 150  [t=365.13s]  <loss = 21.30755, ref_loss = 586.79346, max_Q = 25.7,  mse_true = 277.084>
Epoch 85 / 150  [t=369.12s]  <loss = 20.40063, ref_loss = 535.49188, max_Q = 24.6,  mse_true = 259.647>
Epoch 86 / 150  [t=373.11s]  <loss = 21.01887, ref_loss = 563.38818, max_Q = 25.2,  mse_true = 269.997>
Epoch 87 / 150  [t=377.03s]  <loss = 21.02679, ref_loss = 561.64001, max_Q = 25.0,  mse_true = 269.426>
Epoch 88 / 150  [t=382.23s]  <loss = 21.03269, ref_loss = 558.53705, max_Q = 24.9,  mse_true = 270.315>
Epoch 89 / 150  [t=386.27s]  <loss = 21.12017, ref_loss = 581.29272, max_Q = 25.4,  mse_true = 275.055>
Epoch 90 / 150  [t=390.18s]  <loss = 20.88843, ref_loss = 558.60



Epoch 1 / 150  [t=4.84s]  <loss = 0.23666, ref_loss = 0.22415, max_Q = 1.0,  mse_true = 92.668>    
Epoch 2 / 150  [t=10.26s]  <loss = 0.19686, ref_loss = 0.69052, max_Q = 2.1,  mse_true = 77.398>   
Epoch 3 / 150  [t=14.25s]  <loss = 0.25955, ref_loss = 2.99233, max_Q = 2.9,  mse_true = 64.614>   
Epoch 4 / 150  [t=18.19s]  <loss = 0.42582, ref_loss = 6.17252, max_Q = 3.9,  mse_true = 54.676>   
Epoch 5 / 150  [t=22.04s]  <loss = 0.76105, ref_loss = 12.95489, max_Q = 5.3,  mse_true = 43.063>  
Epoch 6 / 150  [t=26.02s]  <loss = 1.16497, ref_loss = 16.73214, max_Q = 7.0,  mse_true = 35.010>  
Epoch 7 / 150  [t=29.98s]  <loss = 1.55624, ref_loss = 23.44584, max_Q = 8.1,  mse_true = 28.020>  
Epoch 8 / 150  [t=34.04s]  <loss = 1.93999, ref_loss = 29.31461, max_Q = 10.1,  mse_true = 22.111> 
Epoch 9 / 150  [t=38.01s]  <loss = 1.57146, ref_loss = 26.26852, max_Q = 9.9,  mse_true = 17.400>  
Epoch 10 / 150  [t=41.84s]  <loss = 1.79795, ref_loss = 27.13691, max_Q = 11.7,  mse_true = 13.494>


Epoch 81 / 150  [t=355.02s]  <loss = 21.16934, ref_loss = 500.09967, max_Q = 24.5,  mse_true = 236.479>
Epoch 82 / 150  [t=359.14s]  <loss = 23.02880, ref_loss = 541.82043, max_Q = 25.7,  mse_true = 265.092>
Epoch 83 / 150  [t=363.17s]  <loss = 26.89902, ref_loss = 605.57422, max_Q = 26.6,  mse_true = 302.052>
Epoch 84 / 150  [t=367.23s]  <loss = 21.34974, ref_loss = 517.84253, max_Q = 25.1,  mse_true = 247.646>
Epoch 85 / 150  [t=371.24s]  <loss = 22.62323, ref_loss = 541.81781, max_Q = 25.5,  mse_true = 265.534>
Epoch 86 / 150  [t=377.44s]  <loss = 24.56201, ref_loss = 589.80304, max_Q = 26.5,  mse_true = 297.934>
Epoch 87 / 150  [t=381.54s]  <loss = 27.69460, ref_loss = 626.10663, max_Q = 27.5,  mse_true = 325.604>
Epoch 88 / 150  [t=385.58s]  <loss = 27.23139, ref_loss = 678.50464, max_Q = 28.5,  mse_true = 360.765>
Epoch 89 / 150  [t=389.75s]  <loss = 30.72142, ref_loss = 730.92523, max_Q = 29.7,  mse_true = 400.206>
Epoch 90 / 150  [t=393.87s]  <loss = 36.87233, ref_loss = 776.02
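
To compare runs such as the ones above, where `mse_true` falls over the first ten epochs but has climbed into the hundreds by epochs 80 to 90, the per-epoch summary lines can be parsed back into records. This is a rough sketch under the assumption that the lines keep exactly the format printed by the training loop; `EPOCH_LINE` and `parse_epoch_line` are hypothetical names, and truncated lines are skipped rather than repaired.

```python
import re

# Matches one complete epoch summary line as written by the training loop
EPOCH_LINE = re.compile(
    r"Epoch (?P<epoch>\d+) / (?P<max_epochs>\d+)\s+\[t=(?P<t>[\d.]+)s\]\s+"
    r"<loss = (?P<loss>[\d.]+), ref_loss = (?P<ref_loss>[\d.]+), "
    r"max_Q = (?P<max_Q>[\d.]+),\s+mse_true = (?P<mse_true>[\d.]+)>"
)

def parse_epoch_line(line):
    """Return a dict of monitor values, or None if the line is incomplete."""
    match = EPOCH_LINE.search(line)
    if match is None:
        return None
    fields = match.groupdict()
    record = {key: float(fields[key]) for key in ("t", "loss", "ref_loss", "max_Q", "mse_true")}
    record["epoch"]      = int(fields["epoch"])
    record["max_epochs"] = int(fields["max_epochs"])
    return record

complete  = "Epoch 10 / 150  [t=43.60s]  <loss = 2.17651, ref_loss = 31.20228, max_Q = 10.1,  mse_true = 20.863>"
truncated = "Epoch 91 / 150  [t=405.39s]  <loss = 17.40539, ref_loss = 452.3045"

record  = parse_epoch_line(complete)
skipped = parse_epoch_line(truncated)
```

The same monitors are also stored by `save_objects`, so parsing stdout is only useful when the pickled records are unavailable.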

In [26]:

num_runs = 3

for run_idx in range(num_runs) :
    run_experiment(run_config_clone_20, run_idx+1, verbose=True)
    

Directory figures/Helicopter_NB0/clone_20 created
Directory figures/Helicopter_NB0/clone_20/experiment_1 created
Environment config:
> distance_to_end: 13
> game_board_pad_size: 3
> horizontal_size: 20
> reward_per_turn: -1.0
> lambda_dx: 0.0
> lambda_b: -1.0
> gamma: 1.0
> x_min: 0
> x_max: 19
> x_range: 19
> x_start: 3
> x_end: 16
> action_list: [-1  0  1]
Training config:
> Stop training when mse_true exceeds 0.05
> Using 1 step empirical returns
> Using bootstrap method: clone
> Using epochs of length None
> Updating gradient every batch of size 10
> Using optimizer_q1 <keras.optimizer_v2.gradient_descent.SGD object at 0x36ca9e0a0> with learning rate 0.01
> Using optimizer_q2 None with learning rate 0.01
> Swapping q1 and q2 every -1 epochs
> Cloning q2 from q1 every 20 epochs
> Assigning a weight of 0.3 to anchoring state/action pairs


  y_min   = np.nanmin([0, np.nanmin(true_q), np.nanmin(target_q), np.nanmin(q_model), np.nanmin(bs_model)])
  y_max   = np.nanmax([0, np.nanmax(true_q), np.nanmax(target_q), np.nanmax(q_model), np.nanmax(bs_model)])


Epoch 1 / 400  [t=3.20s]  <loss = 0.82902, ref_loss = 0.98630, max_Q = 0.2,  mse_true = 92.668>    
Epoch 2 / 400  [t=5.13s]  <loss = 0.52303, ref_loss = 0.49684, max_Q = 0.4,  mse_true = 88.641>    
Epoch 3 / 400  [t=6.92s]  <loss = 0.30445, ref_loss = 0.22834, max_Q = 0.6,  mse_true = 85.549>    
Epoch 4 / 400  [t=8.71s]  <loss = 0.15454, ref_loss = 0.09193, max_Q = 0.8,  mse_true = 83.170>    
Epoch 5 / 400  [t=10.61s]  <loss = 0.09641, ref_loss = 0.03274, max_Q = 0.9,  mse_true = 81.458>   
Epoch 6 / 400  [t=12.58s]  <loss = 0.06741, ref_loss = 0.00950, max_Q = 1.0,  mse_true = 80.182>   
Epoch 7 / 400  [t=14.41s]  <loss = 0.04173, ref_loss = 0.00198, max_Q = 1.0,  mse_true = 79.250>   
Epoch 8 / 400  [t=16.23s]  <loss = 0.03525, ref_loss = 0.00159, max_Q = 1.1,  mse_true = 78.699>   
Epoch 9 / 400  [t=17.99s]  <loss = 0.03793, ref_loss = 0.00220, max_Q = 1.1,  mse_true = 78.317>   
Epoch 10 / 400  [t=19.77s]  <loss = 0.03058, ref_loss = 0.00323, max_Q = 1.1,  mse_true = 78.002>  


Epoch 83 / 400  [t=171.52s]  <loss = 0.28792, ref_loss = 0.43961, max_Q = 5.0,  mse_true = 38.345> 
Epoch 84 / 400  [t=173.18s]  <loss = 0.27953, ref_loss = 0.37422, max_Q = 5.1,  mse_true = 37.307> 
Epoch 85 / 400  [t=174.79s]  <loss = 0.27526, ref_loss = 0.33422, max_Q = 5.1,  mse_true = 36.937> 
Epoch 86 / 400  [t=176.66s]  <loss = 0.26036, ref_loss = 0.33558, max_Q = 5.2,  mse_true = 36.341> 
Epoch 87 / 400  [t=178.62s]  <loss = 0.26354, ref_loss = 0.31001, max_Q = 5.2,  mse_true = 36.634> 
Epoch 88 / 400  [t=180.24s]  <loss = 0.28019, ref_loss = 0.26474, max_Q = 5.2,  mse_true = 36.434> 
Epoch 89 / 400  [t=181.94s]  <loss = 0.29359, ref_loss = 0.26872, max_Q = 5.2,  mse_true = 35.956> 
Epoch 90 / 400  [t=183.59s]  <loss = 0.26468, ref_loss = 0.30872, max_Q = 5.2,  mse_true = 35.622> 
Epoch 91 / 400  [t=187.52s]  <loss = 0.24943, ref_loss = 0.28426, max_Q = 5.3,  mse_true = 36.450> 
Epoch 92 / 400  [t=189.26s]  <loss = 0.27196, ref_loss = 0.22915, max_Q = 5.2,  mse_true = 36.825> 


Epoch 165 / 400  [t=331.73s]  <loss = 0.15231, ref_loss = 0.09570, max_Q = 9.2,  mse_true = 14.688>
Epoch 166 / 400  [t=334.08s]  <loss = 0.14923, ref_loss = 0.08827, max_Q = 9.3,  mse_true = 14.423>
Epoch 167 / 400  [t=335.86s]  <loss = 0.15019, ref_loss = 0.09478, max_Q = 9.2,  mse_true = 14.794>
Epoch 168 / 400  [t=337.87s]  <loss = 0.14690, ref_loss = 0.08986, max_Q = 9.3,  mse_true = 14.319>
Epoch 169 / 400  [t=339.57s]  <loss = 0.15089, ref_loss = 0.10350, max_Q = 9.3,  mse_true = 13.969>
Epoch 170 / 400  [t=341.34s]  <loss = 0.14816, ref_loss = 0.09842, max_Q = 9.2,  mse_true = 14.523>
Epoch 171 / 400  [t=345.23s]  <loss = 0.16571, ref_loss = 0.07257, max_Q = 9.3,  mse_true = 14.874>
Epoch 172 / 400  [t=347.13s]  <loss = 0.17086, ref_loss = 0.09716, max_Q = 9.2,  mse_true = 14.386>
Epoch 173 / 400  [t=348.79s]  <loss = 0.14291, ref_loss = 0.13146, max_Q = 9.2,  mse_true = 13.993>
Epoch 174 / 400  [t=350.94s]  <loss = 0.14685, ref_loss = 0.08444, max_Q = 9.2,  mse_true = 14.027>


Epoch 247 / 400  [t=492.02s]  <loss = 0.08683, ref_loss = 0.05526, max_Q = 12.7,  mse_true = 4.012>
Epoch 248 / 400  [t=493.64s]  <loss = 0.08693, ref_loss = 0.03704, max_Q = 12.8,  mse_true = 4.247>
Epoch 249 / 400  [t=495.26s]  <loss = 0.09626, ref_loss = 0.03766, max_Q = 12.8,  mse_true = 4.186>
Epoch 250 / 400  [t=497.56s]  <loss = 0.08511, ref_loss = 0.04489, max_Q = 12.7,  mse_true = 4.014>
Epoch 251 / 400  [t=500.81s]  <loss = 0.08260, ref_loss = 0.04378, max_Q = 12.7,  mse_true = 4.211>
Epoch 252 / 400  [t=502.43s]  <loss = 0.08638, ref_loss = 0.04233, max_Q = 12.7,  mse_true = 3.956>
Epoch 253 / 400  [t=504.65s]  <loss = 0.08532, ref_loss = 0.04161, max_Q = 12.7,  mse_true = 3.739>
Epoch 254 / 400  [t=506.28s]  <loss = 0.08195, ref_loss = 0.04312, max_Q = 12.7,  mse_true = 3.865>
Epoch 255 / 400  [t=507.91s]  <loss = 0.08235, ref_loss = 0.05022, max_Q = 12.7,  mse_true = 4.051>
Epoch 256 / 400  [t=509.72s]  <loss = 0.08741, ref_loss = 0.02290, max_Q = 12.7,  mse_true = 4.357>


Epoch 329 / 400  [t=649.65s]  <loss = 0.03582, ref_loss = 0.01785, max_Q = 15.2,  mse_true = 0.989>
Epoch 330 / 400  [t=651.28s]  <loss = 0.03696, ref_loss = 0.01183, max_Q = 15.3,  mse_true = 1.034>
Epoch 331 / 400  [t=654.61s]  <loss = 0.03914, ref_loss = 0.01476, max_Q = 15.2,  mse_true = 0.905>
Epoch 332 / 400  [t=657.09s]  <loss = 0.04761, ref_loss = 0.01047, max_Q = 15.2,  mse_true = 0.843>
Epoch 333 / 400  [t=658.72s]  <loss = 0.03829, ref_loss = 0.01361, max_Q = 15.3,  mse_true = 0.783>
Epoch 334 / 400  [t=660.36s]  <loss = 0.03563, ref_loss = 0.01279, max_Q = 15.3,  mse_true = 1.008>
Epoch 335 / 400  [t=662.00s]  <loss = 0.04246, ref_loss = 0.00730, max_Q = 15.3,  mse_true = 0.972>
Epoch 336 / 400  [t=663.80s]  <loss = 0.03402, ref_loss = 0.02010, max_Q = 15.1,  mse_true = 1.009>
Epoch 337 / 400  [t=665.42s]  <loss = 0.03594, ref_loss = 0.00843, max_Q = 15.3,  mse_true = 0.940>
Epoch 338 / 400  [t=667.04s]  <loss = 0.03358, ref_loss = 0.01694, max_Q = 15.2,  mse_true = 0.906>




Epoch 1 / 400  [t=2.72s]  <loss = 0.83307, ref_loss = 0.98365, max_Q = 0.2,  mse_true = 92.668>    
Epoch 2 / 400  [t=4.34s]  <loss = 0.51306, ref_loss = 0.48820, max_Q = 0.4,  mse_true = 88.620>    
Epoch 3 / 400  [t=7.07s]  <loss = 0.28832, ref_loss = 0.22524, max_Q = 0.6,  mse_true = 85.554>    
Epoch 4 / 400  [t=8.70s]  <loss = 0.15806, ref_loss = 0.09328, max_Q = 0.8,  mse_true = 83.234>    
Epoch 5 / 400  [t=10.35s]  <loss = 0.09012, ref_loss = 0.03287, max_Q = 0.9,  mse_true = 81.507>   
Epoch 6 / 400  [t=12.18s]  <loss = 0.05896, ref_loss = 0.00945, max_Q = 1.0,  mse_true = 80.258>   
Epoch 7 / 400  [t=13.81s]  <loss = 0.04327, ref_loss = 0.00242, max_Q = 1.0,  mse_true = 79.385>   
Epoch 8 / 400  [t=15.43s]  <loss = 0.03626, ref_loss = 0.00137, max_Q = 1.1,  mse_true = 78.790>   
Epoch 9 / 400  [t=17.05s]  <loss = 0.03234, ref_loss = 0.00187, max_Q = 1.1,  mse_true = 78.360>   
Epoch 10 / 400  [t=18.66s]  <loss = 0.03615, ref_loss = 0.00289, max_Q = 1.1,  mse_true = 78.062>  


Epoch 83 / 400  [t=157.47s]  <loss = 0.28499, ref_loss = 0.42299, max_Q = 5.0,  mse_true = 38.912> 
Epoch 84 / 400  [t=159.10s]  <loss = 0.27181, ref_loss = 0.39955, max_Q = 5.1,  mse_true = 37.740> 
Epoch 85 / 400  [t=160.72s]  <loss = 0.29399, ref_loss = 0.29123, max_Q = 5.1,  mse_true = 37.737> 
Epoch 86 / 400  [t=163.88s]  <loss = 0.26006, ref_loss = 0.33968, max_Q = 5.2,  mse_true = 36.648> 
Epoch 87 / 400  [t=165.49s]  <loss = 0.25523, ref_loss = 0.30478, max_Q = 5.2,  mse_true = 37.069> 
Epoch 88 / 400  [t=167.12s]  <loss = 0.25439, ref_loss = 0.29537, max_Q = 5.2,  mse_true = 37.017> 
Epoch 89 / 400  [t=168.76s]  <loss = 0.25982, ref_loss = 0.22665, max_Q = 5.2,  mse_true = 37.169> 
Epoch 90 / 400  [t=170.42s]  <loss = 0.25175, ref_loss = 0.26405, max_Q = 5.2,  mse_true = 36.482> 
Epoch 91 / 400  [t=173.44s]  <loss = 0.24994, ref_loss = 0.26437, max_Q = 5.2,  mse_true = 36.788> 
Epoch 92 / 400  [t=175.09s]  <loss = 0.27248, ref_loss = 0.23773, max_Q = 5.2,  mse_true = 37.005> 


Epoch 165 / 400  [t=313.08s]  <loss = 0.15926, ref_loss = 0.09659, max_Q = 9.5,  mse_true = 13.434>
Epoch 166 / 400  [t=314.87s]  <loss = 0.15395, ref_loss = 0.10152, max_Q = 9.4,  mse_true = 13.713>
Epoch 167 / 400  [t=316.47s]  <loss = 0.16365, ref_loss = 0.10455, max_Q = 9.4,  mse_true = 13.525>
Epoch 168 / 400  [t=318.06s]  <loss = 0.15378, ref_loss = 0.09327, max_Q = 9.4,  mse_true = 13.948>
Epoch 169 / 400  [t=319.65s]  <loss = 0.15235, ref_loss = 0.09967, max_Q = 9.3,  mse_true = 13.912>
Epoch 170 / 400  [t=321.25s]  <loss = 0.14776, ref_loss = 0.10525, max_Q = 9.4,  mse_true = 13.846>
Epoch 171 / 400  [t=324.35s]  <loss = 0.14641, ref_loss = 0.09415, max_Q = 9.4,  mse_true = 13.619>
Epoch 172 / 400  [t=325.96s]  <loss = 0.15017, ref_loss = 0.08122, max_Q = 9.3,  mse_true = 14.106>
Epoch 173 / 400  [t=327.57s]  <loss = 0.14837, ref_loss = 0.09262, max_Q = 9.4,  mse_true = 13.778>
Epoch 174 / 400  [t=329.18s]  <loss = 0.14304, ref_loss = 0.10072, max_Q = 9.4,  mse_true = 13.922>


Epoch 247 / 400  [t=468.20s]  <loss = 0.10154, ref_loss = 0.04559, max_Q = 13.2,  mse_true = 3.229>
Epoch 248 / 400  [t=469.81s]  <loss = 0.09090, ref_loss = 0.06504, max_Q = 13.3,  mse_true = 3.202>
Epoch 249 / 400  [t=471.44s]  <loss = 0.09406, ref_loss = 0.03988, max_Q = 13.1,  mse_true = 3.448>
Epoch 250 / 400  [t=473.06s]  <loss = 0.09046, ref_loss = 0.05209, max_Q = 13.1,  mse_true = 3.381>
Epoch 251 / 400  [t=476.25s]  <loss = 0.08799, ref_loss = 0.04005, max_Q = 13.3,  mse_true = 3.177>
Epoch 252 / 400  [t=477.86s]  <loss = 0.10325, ref_loss = 0.03806, max_Q = 13.1,  mse_true = 3.648>
Epoch 253 / 400  [t=479.47s]  <loss = 0.08980, ref_loss = 0.04794, max_Q = 13.1,  mse_true = 3.293>
Epoch 254 / 400  [t=481.08s]  <loss = 0.08704, ref_loss = 0.05465, max_Q = 13.1,  mse_true = 3.315>
Epoch 255 / 400  [t=482.69s]  <loss = 0.08438, ref_loss = 0.05235, max_Q = 13.2,  mse_true = 3.449>
Epoch 256 / 400  [t=484.48s]  <loss = 0.08772, ref_loss = 0.03935, max_Q = 13.2,  mse_true = 3.368>


Epoch 329 / 400  [t=623.49s]  <loss = 0.04817, ref_loss = 0.02208, max_Q = 15.8,  mse_true = 0.664>
Epoch 330 / 400  [t=625.09s]  <loss = 0.04205, ref_loss = 0.02159, max_Q = 15.8,  mse_true = 0.751>
Epoch 331 / 400  [t=628.37s]  <loss = 0.03788, ref_loss = 0.01657, max_Q = 15.7,  mse_true = 0.685>
Epoch 332 / 400  [t=629.98s]  <loss = 0.04146, ref_loss = 0.01106, max_Q = 15.8,  mse_true = 0.714>
Epoch 333 / 400  [t=631.59s]  <loss = 0.03990, ref_loss = 0.01596, max_Q = 15.7,  mse_true = 0.665>
Epoch 334 / 400  [t=633.19s]  <loss = 0.03679, ref_loss = 0.02043, max_Q = 15.7,  mse_true = 0.653>
Epoch 335 / 400  [t=634.80s]  <loss = 0.03582, ref_loss = 0.01316, max_Q = 15.8,  mse_true = 0.619>
Epoch 336 / 400  [t=636.59s]  <loss = 0.03899, ref_loss = 0.01443, max_Q = 15.7,  mse_true = 0.656>
Epoch 337 / 400  [t=638.18s]  <loss = 0.05412, ref_loss = 0.01453, max_Q = 15.7,  mse_true = 0.818>
Epoch 338 / 400  [t=639.77s]  <loss = 0.03937, ref_loss = 0.01271, max_Q = 15.7,  mse_true = 0.744>


Epoch 1 / 400  [t=2.70s]  <loss = 0.82893, ref_loss = 0.97736, max_Q = 0.2,  mse_true = 92.668>    
Epoch 2 / 400  [t=4.31s]  <loss = 0.50334, ref_loss = 0.49446, max_Q = 0.4,  mse_true = 88.623>    
Epoch 3 / 400  [t=5.91s]  <loss = 0.29315, ref_loss = 0.22970, max_Q = 0.6,  mse_true = 85.601>    
Epoch 4 / 400  [t=7.54s]  <loss = 0.15975, ref_loss = 0.09492, max_Q = 0.8,  mse_true = 83.265>    
Epoch 5 / 400  [t=9.17s]  <loss = 0.09153, ref_loss = 0.03479, max_Q = 0.9,  mse_true = 81.528>    
Epoch 6 / 400  [t=10.96s]  <loss = 0.05771, ref_loss = 0.00951, max_Q = 1.0,  mse_true = 80.245>   
Epoch 7 / 400  [t=12.56s]  <loss = 0.04908, ref_loss = 0.00250, max_Q = 1.0,  mse_true = 79.381>   
Epoch 8 / 400  [t=14.20s]  <loss = 0.03518, ref_loss = 0.00127, max_Q = 1.1,  mse_true = 78.732>   
Epoch 9 / 400  [t=18.10s]  <loss = 0.03216, ref_loss = 0.00192, max_Q = 1.1,  mse_true = 78.328>   
Epoch 10 / 400  [t=19.72s]  <loss = 0.03045, ref_loss = 0.00278, max_Q = 1.1,  mse_true = 78.041>  


Epoch 83 / 400  [t=161.06s]  <loss = 0.30188, ref_loss = 0.44335, max_Q = 5.1,  mse_true = 38.347> 
Epoch 84 / 400  [t=162.72s]  <loss = 0.27168, ref_loss = 0.41853, max_Q = 5.2,  mse_true = 37.045> 
Epoch 85 / 400  [t=164.74s]  <loss = 0.26897, ref_loss = 0.33936, max_Q = 5.2,  mse_true = 37.206> 
Epoch 86 / 400  [t=166.58s]  <loss = 0.26748, ref_loss = 0.28345, max_Q = 5.2,  mse_true = 37.148> 
Epoch 87 / 400  [t=168.28s]  <loss = 0.28551, ref_loss = 0.29237, max_Q = 5.2,  mse_true = 36.385> 
Epoch 88 / 400  [t=170.08s]  <loss = 0.25697, ref_loss = 0.29656, max_Q = 5.3,  mse_true = 36.348> 
Epoch 89 / 400  [t=171.87s]  <loss = 0.25535, ref_loss = 0.26134, max_Q = 5.2,  mse_true = 36.890> 
Epoch 90 / 400  [t=173.67s]  <loss = 0.26122, ref_loss = 0.25696, max_Q = 5.2,  mse_true = 36.288> 
Epoch 91 / 400  [t=179.53s]  <loss = 0.25260, ref_loss = 0.26852, max_Q = 5.2,  mse_true = 36.199> 
Epoch 92 / 400  [t=181.32s]  <loss = 0.25291, ref_loss = 0.23738, max_Q = 5.2,  mse_true = 36.853> 


Epoch 165 / 400  [t=328.65s]  <loss = 0.16307, ref_loss = 0.11132, max_Q = 9.3,  mse_true = 14.154>
Epoch 166 / 400  [t=330.62s]  <loss = 0.16654, ref_loss = 0.12072, max_Q = 9.3,  mse_true = 13.424>
Epoch 167 / 400  [t=332.40s]  <loss = 0.15292, ref_loss = 0.12158, max_Q = 9.4,  mse_true = 13.915>
Epoch 168 / 400  [t=334.18s]  <loss = 0.15506, ref_loss = 0.09682, max_Q = 9.3,  mse_true = 14.187>
Epoch 169 / 400  [t=335.97s]  <loss = 0.15411, ref_loss = 0.10312, max_Q = 9.3,  mse_true = 13.992>
Epoch 170 / 400  [t=337.73s]  <loss = 0.15357, ref_loss = 0.10294, max_Q = 9.3,  mse_true = 14.571>
Epoch 171 / 400  [t=341.00s]  <loss = 0.16166, ref_loss = 0.10126, max_Q = 9.3,  mse_true = 14.311>
Epoch 172 / 400  [t=342.73s]  <loss = 0.15326, ref_loss = 0.10345, max_Q = 9.3,  mse_true = 14.571>
Epoch 173 / 400  [t=344.48s]  <loss = 0.14918, ref_loss = 0.11025, max_Q = 9.3,  mse_true = 13.958>
Epoch 174 / 400  [t=346.24s]  <loss = 0.16670, ref_loss = 0.08593, max_Q = 9.3,  mse_true = 14.356>


Epoch 247 / 400  [t=492.73s]  <loss = 0.09316, ref_loss = 0.05532, max_Q = 12.6,  mse_true = 4.137>
Epoch 248 / 400  [t=494.38s]  <loss = 0.08830, ref_loss = 0.06553, max_Q = 12.6,  mse_true = 4.246>
Epoch 249 / 400  [t=496.03s]  <loss = 0.08834, ref_loss = 0.04037, max_Q = 12.7,  mse_true = 4.348>
Epoch 250 / 400  [t=497.67s]  <loss = 0.09259, ref_loss = 0.05055, max_Q = 12.6,  mse_true = 4.738>
Epoch 251 / 400  [t=500.90s]  <loss = 0.09055, ref_loss = 0.04982, max_Q = 12.7,  mse_true = 4.520>
Epoch 252 / 400  [t=502.56s]  <loss = 0.08937, ref_loss = 0.04536, max_Q = 12.7,  mse_true = 4.590>
Epoch 253 / 400  [t=504.21s]  <loss = 0.08460, ref_loss = 0.03654, max_Q = 12.7,  mse_true = 4.247>
Epoch 254 / 400  [t=505.85s]  <loss = 0.08666, ref_loss = 0.04325, max_Q = 12.6,  mse_true = 4.296>
Epoch 255 / 400  [t=507.49s]  <loss = 0.08281, ref_loss = 0.04674, max_Q = 12.7,  mse_true = 4.305>
Epoch 256 / 400  [t=509.32s]  <loss = 0.09176, ref_loss = 0.03233, max_Q = 12.6,  mse_true = 4.578>


Epoch 329 / 400  [t=652.35s]  <loss = 0.05147, ref_loss = 0.01944, max_Q = 15.7,  mse_true = 0.694>
Epoch 330 / 400  [t=653.96s]  <loss = 0.04172, ref_loss = 0.01584, max_Q = 15.9,  mse_true = 0.639>
Epoch 331 / 400  [t=657.66s]  <loss = 0.04019, ref_loss = 0.01334, max_Q = 15.7,  mse_true = 0.691>
Epoch 332 / 400  [t=659.32s]  <loss = 0.05480, ref_loss = 0.01942, max_Q = 16.1,  mse_true = 0.617>
Epoch 333 / 400  [t=660.99s]  <loss = 0.04231, ref_loss = 0.01033, max_Q = 15.7,  mse_true = 0.681>
Epoch 334 / 400  [t=662.63s]  <loss = 0.03993, ref_loss = 0.01790, max_Q = 15.9,  mse_true = 0.648>
Epoch 335 / 400  [t=664.26s]  <loss = 0.03658, ref_loss = 0.01248, max_Q = 15.8,  mse_true = 0.663>
Epoch 336 / 400  [t=666.09s]  <loss = 0.03749, ref_loss = 0.01258, max_Q = 15.8,  mse_true = 0.655>
Epoch 337 / 400  [t=667.70s]  <loss = 0.03609, ref_loss = 0.01381, max_Q = 15.8,  mse_true = 0.656>
Epoch 338 / 400  [t=672.54s]  <loss = 0.03716, ref_loss = 0.01139, max_Q = 15.7,  mse_true = 0.650>
