In [1]:
#  Required imports

import math, os, sys, time

import numpy as np

from matplotlib import pyplot as plt

from scipy import stats

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import BatchNormalization, Conv2D, Concatenate, Dense, Dropout, Flatten, Input, MaxPooling2D, Rescaling
from tensorflow.keras.models import Model

#  List every physical device TensorFlow can see (e.g. to confirm that a GPU is available)
print("TensorFlow has found devices:")
for device in tf.config.list_physical_devices() :
    print(f"-  {device}")
    
#  Use non-interactive backend to prevent memory leak from creation and non-deletion of background GUI objects
#  > interactive backends such as 'inline' cause plt.close(fig) to not release memory, causing a large memory leak
#  > plt.clf() followed by plt.close(fig) reduces the leak, but a small leak from GUI objects remains
#  > non-interactive backend allows plt.close(fig) to work normally with no leak, because no GUI objects are
#    created (this means we cannot use things like plt.show(fig))
%matplotlib agg
    



TensorFlow has found devices:
-  PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
-  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [2]:
###
###  Global constants
###

#  The game board has a fixed size for now (no architecture for variable board sizes yet). The playable
#  region is surrounded by a padding margin, so the full grid spans (horizontal_max, vertical_max) cells.

##  Board geometry
horizontal_size, vertical_size = 8, 5
horizontal_pad,  vertical_pad  = 2, 2
horizontal_max = horizontal_size + 2 * horizontal_pad
vertical_max   = vertical_size   + 2 * vertical_pad
r_norm         = max(horizontal_max, vertical_max)   # length scale used to rescale agent positions in the model

##  Weather
storm_radius = 2.
num_storms   = 2

##  Rewards (see perform_action for how each factor is used)
reward_per_turn = -1.
lambda_r        = 0.     # distance-based reward factor
lambda_b        = -2.    # penalty for trying to leave the board
lambda_w        = -15.   # weather-intensity reward factor
gamma           = .98    # presumably the return discount factor (not used in this cell)

##  Start = (min x, max y) and end = (max x, min y) corners of the playable region
r_start = np.array([horizontal_pad, vertical_pad + vertical_size - 1])
r_end   = np.array([horizontal_pad + horizontal_size - 1, vertical_pad])

##  All 9 moves (dx, dy) with dx, dy in {-1, 0, +1}, ordered dx-major (index 4 is "stay")
action_list = np.array([[dx, dy] for dx in (-1, 0, 1) for dy in (-1, 0, 1)])

for point_name, point in (("r_start", r_start), ("r_end  ", r_end)) :
    print(f"{point_name} = {point}")


In [3]:
###
###  Define methods used to calculate Lp measures and norms
###


def Lp_norm(v, p=2) :
    '''
    Calculate the Lp norm of vector v, defined as [sum_i |v_i|^p]^(1/p). If np.isfinite(p) returns False
    (e.g. p = np.inf) then calculate the L-infinity norm instead, i.e. the largest absolute component.
    Inputs:
      > v, array-like of any shape >= 1D
        vector to calculate the norm of (converted with np.asarray if it is not already an ndarray)
      > p, float, default=2
        exponent of the Lp norm
    Return:
      > float, the Lp norm
    '''
    ##  Work with absolute values: without them, odd or non-integer p gives wrong results (or NaN) for
    ##  vectors with negative components, e.g. the L1 norm of [-1, 1] would come out as 0
    abs_v = np.fabs(np.asarray(v))
    ##  Return L-infinity norm if p is not a real number
    if not np.isfinite(p) :
        return abs_v.max()
    ##  Use np.power method to return Lp norm if p is finite
    return np.power(np.power(abs_v, p).sum(), 1./p)
  
    
def Lp_distance(v1, v2, p=2) :
    '''
    Return the Lp distance between two vectors, i.e. Lp_norm(v2 - v1, p=p).
    Inputs:
      > v1, array-like of any shape >= 1D
        first vector (converted with np.array if it is not already an ndarray)
      > v2, array-like with the same shape as v1
        second vector (converted with np.array if it is not already an ndarray)
      > p, float, default=2
        exponent of the Lp distance
    Return:
      > float, the Lp distance
    '''
    first  = v1 if type(v1) == np.ndarray else np.array(v1)
    second = v2 if type(v2) == np.ndarray else np.array(v2)
    difference = second - first
    return Lp_norm(difference, p=p)


In [None]:
###
###  Define environment methods
###


def create_weather_map(_num_storms=num_storms, _storm_radius=storm_radius) :
    '''
    Generate a random weather map. _num_storms storm centres are drawn uniformly at random on the grid, and
    every cell receives a contribution 1 / (1 + d / _storm_radius) from each storm, where d is the
    Euclidean distance from the cell to that storm centre.
    Inputs:
      > _num_storms, int, default=num_storms
        number of storms to place
      > _storm_radius, float, default=storm_radius
        distance scale over which each storm's intensity falls off
    Returns:
      > np.ndarray of shape (horizontal_max, vertical_max, 1)
        weather intensity at every cell, indexed as [x, y, 0]
    '''
    ##  Sample storm centres (x first, then y: same order of random draws as a per-storm loop)
    x_storms = np.random.uniform(low=0, high=horizontal_max-1, size=(_num_storms,))
    y_storms = np.random.uniform(low=0, high=vertical_max  -1, size=(_num_storms,))
    ##  Vectorised evaluation: cell coordinates of shape (x, 1, 1) and (1, y, 1) broadcast against the
    ##  storm coordinates of shape (storm,), giving distances of shape (horizontal_max, vertical_max, storm)
    xs = np.arange(horizontal_max, dtype=float).reshape(-1, 1, 1)
    ys = np.arange(vertical_max  , dtype=float).reshape(1, -1, 1)
    dr = np.sqrt(np.square(x_storms - xs) + np.square(y_storms - ys))
    ##  Sum the contributions of all storms (zero everywhere if there are no storms)
    intensity = (1. / (1. + dr / _storm_radius)).sum(axis=-1)
    return intensity[:, :, np.newaxis]


def perform_action(weather_map, r_agent, action, base_reward=reward_per_turn, boundary_reward=lambda_b, 
                   dr_reward_factor=lambda_r, w_reward_factor=lambda_w, verbose=False) :
    '''
    Given the current environment and agent states, perform the specified action and return the reward
    obtained along with the new agent state.
    Inputs:
      > weather_map, np.ndarray of shape (horizontal_max, vertical_max, 1)
        intensity of weather at every pixel in the map
      > r_agent, np.ndarray of shape (2,)
        (x,y) position of agent at initial timestep
      > action, np.ndarray of shape (2,)
        (dx,dy) of action to be performed, each component expected to be one of {-1, 0, +1}
      > base_reward, float, default=reward_per_turn
        basic reward returned every turn (expected -ve)
      > boundary_reward, float, default=lambda_b
        reward received when the move would leave the game board (expected -ve); the agent then stays put
      > dr_reward_factor, float, default=lambda_r
        factor multiplied by (decrease in distance to r_end) / sqrt(2) to calculate movement reward (expected +ve)
      > w_reward_factor, float, default=lambda_w
        factor multiplied by the weather intensity at the agent's starting cell (expected -ve)
      > verbose, bool, default=False
        if True then print a breakdown of the reward
    Returns:
      > float
        reward obtained by performing action (0. if the action reaches the terminal state)
      > np.ndarray of shape (2,)
        (x,y) position of agent at iterated timestep
    Raises:
      > RuntimeError
        if r_agent is already the terminal state or is out of bounds
    '''
    ##  Make sure initial state is valid to protect against unexpected behaviour
    if is_terminal(r_agent) :
        raise RuntimeError(f"Agent state is terminal, so no actions may be performed")
    if is_out_of_bounds(r_agent) :
        raise RuntimeError(f"Agent state {r_agent} is out of bounds, so no actions may be performed")
    ##  Get initial distance of agent from the end
    d_agent = Lp_distance(r_agent, r_end)
    ##  Iterate agent position
    ##  - if the move would leave the board then add the boundary penalty and return agent to original position
    ##  - reuse is_out_of_bounds so this wall check always agrees with the validity check above
    r_agent_p = r_agent + action
    reward_b  = 0
    if is_out_of_bounds(r_agent_p) :
        reward_b  = boundary_reward
        r_agent_p = r_agent.copy()
    ##  Get distance-based reward (sqrt(2) is the length of a diagonal step)
    d_agent_p = Lp_distance(r_agent_p, r_end)
    reward_r  = dr_reward_factor * (d_agent - d_agent_p) / np.sqrt(2)
    ##  Get weather-based reward
    ##  NOTE(review): this is evaluated at the starting cell r_agent, not the destination r_agent_p, so the
    ##  weather term does not depend on the action taken - confirm this is intended
    reward_w  = w_reward_factor * weather_map[int(r_agent[0]), int(r_agent[1]), 0]
    ##  Calculate total reward by summing the base, boundary, distance and weather rewards
    ##  - reaching the terminal state always yields zero reward
    reward = 0. if is_terminal(r_agent_p) else base_reward + reward_b + reward_r + reward_w
    if verbose :
        print(f"perform_action: agent {r_agent} action {action} --> agent {r_agent_p} reward {reward:.2f}  [{base_reward:.2f} (base) + {reward_b:.2f} (b) + {reward_r:.2f} (r) + {reward_w:.2f} (w)]")
    ##  Return reward and new agent state
    return reward, r_agent_p


def is_terminal(r_agent) :
    '''
    Check whether the agent is standing on the end position r_end, i.e. in the terminal state.
    Inputs:
      > r_agent, np.ndarray of shape (2,)
        agent position as (x,y)-coordinates
    Returns:
      > bool
        True when both coordinates match r_end, False otherwise
    '''
    same_x = r_agent[0] == r_end[0]
    if not same_x :
        return False
    return bool(r_agent[1] == r_end[1])


def is_out_of_bounds(r_agent) :
    '''
    Check whether the agent lies outside the full game grid 0 <= x < horizontal_max, 0 <= y < vertical_max.
    Inputs:
      > r_agent, np.ndarray of shape (2,)
        agent position as (x,y)-coordinates
    Returns:
      > bool
        True if the agent is out of bounds, False otherwise
    '''
    x, y = r_agent[0], r_agent[1]
    return bool(x < 0 or y < 0 or x >= horizontal_max or y >= vertical_max)
    

def get_greedy_action(weather_map, r_agent, *q_models) :
    '''
    Sample a greedy action from the Q-value models provided. If multiple models provided then use their mean.
    Inputs:
      > weather_map, np.ndarray of shape (horizontal_max, vertical_max, 1)
        intensity of weather at every pixel in the map
      > r_agent, np.ndarray of shape (2,)
        agent position as (x,y)-coordinates
      > q_models, one or more tf.keras Model instances with inputs [weather map, agent position, action]
        Keras Q(s,a) models
    Returns:
      > np.ndarray of shape (2,)
        action defined by greedy policy over the model(s) at this agent position
      > list of np.ndarray of shape (len(action_list), 1)
        action values in the same order as action_list, one array per model provided
    Raises:
      > ValueError
        if no models are provided
    '''
    if len(q_models) == 0 :
        raise ValueError("get_greedy_action requires at least one q-value model")
    num_actions = len(action_list)
    ##  Evaluate every action at this state in a single batch by repeating the state once per action
    weather_maps = np.repeat(np.asarray(weather_map, dtype=np.float32)[np.newaxis], num_actions, axis=0)
    r_agents     = np.repeat(np.asarray(r_agent    , dtype=np.float32)[np.newaxis], num_actions, axis=0)
    actions      = np.asarray(action_list, dtype=np.float32)
    model_args   = [weather_maps, r_agents, actions]
    ##  Call the models directly instead of model.predict: Keras documents predict as intended for large
    ##  batched inputs, not for small batches inside loops (this is called once per board cell).
    ##  training=False keeps BatchNormalization layers in inference mode, as predict does.
    model_action_values = [np.asarray(model(model_args, training=False)) for model in q_models]
    action_values       = np.mean(model_action_values, axis=0)
    best_action         = action_list[np.argmax(action_values)]
    return best_action, model_action_values


def get_exploration_action(num=1) :
    '''
    Generate actions drawn uniformly at random from every entry of action_list (9 actions by default).
    Input:
      > num, int, default=1
        number of random actions to generate
    Returns :
      > np.ndarray of size (num,2)
        list of actions generated
    '''
    ##  np.random.randint excludes `high`, so use len(action_list) (not 8) so that every action, including
    ##  the last one, can be drawn
    return action_list[np.random.randint(low=0, high=len(action_list), size=(num,))]
    

In [None]:
###
###  Method for creating action-value model
###

def create_action_value_model(name=None) :
    '''
    Create a network for the action-value model q(s, a). It has three input branches that are concatenated
    before the final dense layers:
      - weather map, shape (horizontal_max, vertical_max, 1), processed by a small convolutional stack
      - agent position (x, y), shape (2,), rescaled to roughly [-1, 1] before a dense layer
      - action (dx, dy), shape (2,), fed directly into a dense layer
    Inputs:
      > name, str, default=None
        model name, if None then keras default is used
    Returns:
      > keras Model: uncompiled keras model with inputs [weather map, agent position, action] and a single
        linear output (must be trained using custom loop)
    '''
    input_layer_w = Input((horizontal_max, vertical_max, 1))
    input_layer_r = Input((2,))
    input_layer_a = Input((2,))
    
    ##  Weather branch
    next_layer_w  = Conv2D(20, kernel_size=(2,2), activation="relu")(input_layer_w)
    next_layer_w  = MaxPooling2D(pool_size=(2,2))(next_layer_w)
    next_layer_w  = BatchNormalization()(next_layer_w)
    next_layer_w  = Conv2D(20, kernel_size=(2,2), activation="relu")(next_layer_w)
    next_layer_w  = Flatten()(next_layer_w)
    next_layer_w  = BatchNormalization()(next_layer_w)
    next_layer_w  = Dense(50, activation="relu")(next_layer_w)
    next_layer_w  = BatchNormalization()(next_layer_w)
    
    ##  Position branch
    ##  - Rescaling computes inputs*scale + offset, so scale=2/r_norm with offset=-1 maps [0, r_norm] onto
    ##    [-1, 1]. An offset of -0.5*r_norm (shift-before-scale thinking) would instead give [-r_norm/2, 2-r_norm/2].
    next_layer_r  = Rescaling(scale=2./r_norm, offset=-1.)(input_layer_r)
    next_layer_r  = Dense(50, activation="relu")(next_layer_r)
    next_layer_r  = BatchNormalization()(next_layer_r)
    
    ##  Action branch (components are already in {-1, 0, +1})
    next_layer_a  = Dense(50, activation="relu")(input_layer_a)
    next_layer_a  = BatchNormalization()(next_layer_a)
    
    ##  Combined head
    next_layer    = Concatenate()([next_layer_w, next_layer_r, next_layer_a])
    next_layer    = Dense(500, activation="relu")(next_layer)
    next_layer    = BatchNormalization()(next_layer)
    ##  NOTE(review): two consecutive linear Dense layers compose into a single linear map, so this 500-unit
    ##  layer adds parameters but no expressive power - consider giving it a non-linear activation
    next_layer    = Dense(500, activation="linear")(next_layer)
    output_layer  = Dense(1, activation="linear")(next_layer)
    model         = Model([input_layer_w, input_layer_r, input_layer_a], output_layer, name=name)
    return model
    

In [None]:
def generate_directory_for_file_path(fname, print_msg_on_dir_creation=True) :
    """
    Make sure every directory enclosing fname exists, creating missing ones one level at a time. Call this
    before fig.savefig(fname, ...) so that writing fname cannot fail with a FileNotFoundError.
    Input:
       - fname: str
                path of the file that will be written; if it ends in '/' then the path itself is also
                created as a directory
       - print_msg_on_dir_creation: bool, default = True
                                    if True then print a message whenever a new directory is created
    Raises:
       - RuntimeError if one of the enclosing paths exists but is not a directory
    """
    # Collapse repeated separators so that no empty path components appear
    while "//" in fname :
        fname = fname.replace("//", "/")
    components = fname.split("/")
    # Visit every proper prefix of the path; the final component is the file name itself
    for depth in range(1, len(components)) :
        prefix = "/".join(components[:depth])
        if not prefix :
            continue
        if os.path.isdir(prefix) :
            continue
        if os.path.exists(prefix) :
            raise RuntimeError(f"Cannot create directory {prefix} because it already exists and is not a directory")
        os.mkdir(prefix)
        if print_msg_on_dir_creation :
            print(f"Directory {prefix} created")
    
    
def create_greedy_policy_plot(weather_map, *q_models, epoch_idx=-1, verbose=False, show=False, close=False, save="") :
    '''
    Create a plt.Figure instance visualising the greedy policy defined by the average of the q-value models
    provided. Allows for plot to be shown, saved and/or closed using plt interface. Returns the plot figure
    and axis objects so they can continue to be manipulated, but note that the figure will have been
    cleared and closed if close=True.
    Inputs:
      > weather_map, np.ndarray of shape (horizontal_max, vertical_max, 1)
        weather intensity at every state
      > q_models, one or more keras Model instances
        q-value models whose mean defines the greedy policy
      > epoch_idx, int, default=-1
        if >= 0 then draw a text box displaying how many epochs have been performed
      > verbose, bool, default=False
        if True then print some text to display progress as we evaluate the models for every state
      > show, bool, default=False
        if True then call plt.show() (with a non-interactive backend such as agg this may display nothing)
      > close, bool, default=False
        if True then clear the figure and call plt.close(fig)
      > save, str, default=""
        if string provided then call fig.savefig(save, ...), creating any required subdirectories if needed
    Returns:
      > plt.Figure instance
      > plt.Axes instance
    '''
    #  Set up plot
    fig = plt.figure(figsize=(12*horizontal_max/11.5,12*vertical_max/11.5))
    ax  = fig.add_subplot(1, 1, 1)
    ax.set_xlim(-0.5, horizontal_max-0.5)
    ax.set_ylim(-0.5, vertical_max-0.5)
    ax.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=16)
       
    #  Draw weather! (colour scale runs from 0 to 1; larger intensities share the darkest shade)
    ax.imshow(weather_map[:,:,0].transpose(), origin="lower", alpha=0.5, cmap="Greys", vmin=0, vmax=1)
    
    #  Draw arrows by looping over states and finding greedy action according to q-models
    #  - only the greedy action is needed, so the per-model action values are discarded; this works for
    #    any number of models (unpacking them into exactly two variables failed for anything but two)
    for x in range(horizontal_max) :
        for y in range(vertical_max) :
            if x == r_end[0] and y == r_end[1] : continue
            if verbose :
                sys.stdout.write(f"\rEvaluating greedy policy for agent state ({x}, {y})".ljust(100))
            r_agent     = np.array([x,y])
            (dx, dy), _ = get_greedy_action(weather_map, r_agent, *q_models)
            if dx == 0 and dy == 0 :
                #  "Stay" action is drawn as a dot
                ax.plot(x, y, "o", markersize=8, c="b", alpha=1)
            else :
                ax.arrow(x - 0.3*dx, y - 0.3*dy, 0.6*dx, 0.6*dy, head_width=0.25, length_includes_head=True, color="b")
                
    #  Draw accompanying plot objects
    ax.fill_between([r_start[0]-0.5, r_start[0]+0.5], r_start[1]-0.5, r_start[1]+0.5, color="g", alpha=0.2, label="Start")
    ax.fill_between([r_end  [0]-0.5, r_end  [0]+0.5], r_end  [1]-0.5, r_end  [1]+0.5, color="r", alpha=0.2, label="Finish")
    ax.legend(loc=(0.,1.002), ncol=3, fontsize=16, frameon=False)
    
    #  Draw text boxes displaying title and num. epochs
    ax.text(0, 1.07, "Weather intensity and greedy policy per $s$", transform=ax.transAxes, 
            fontsize=18, weight="bold", ha="left", va="bottom")
    if epoch_idx >= 0 :
        ax.text(1, 1.01, f"After {epoch_idx} epochs", ha="right", va="bottom", weight="bold", 
                transform=ax.transAxes, fontsize=16)
       
    #  Save / show / close
    #  - act on this figure explicitly rather than on pyplot's "current" figure
    if len(save) > 0 :
        generate_directory_for_file_path(save)
        fig.savefig(save, bbox_inches="tight")
    if show :
        #  pyplot.show takes no positional figure argument, so do not pass fig here
        plt.show()
    if close :
        fig.clf()
        plt.close(fig)
        
    #  Return figure and axis
    return fig, ax


def draw_training_curve(ax, container, m, c, label) :
    '''
    Draw one training curve onto ax. Every (x, y) point is drawn with marker style m, and consecutive
    points are joined by a line unless their x values are more than 1.5 apart, so gaps in the record
    (e.g. missing epochs) are not bridged.
    Inputs:
      > ax, plt.Axes instance to draw on
      > container, sequence of (x, y) pairs
      > m, str, matplotlib format string used for the markers
      > c, colour used for both markers and joining lines
      > label, str, legend label attached to the markers
    '''
    xs = [x for x, _ in container]
    ys = [y for _, y in container]
    ax.plot(xs, ys, m, ms=7, c=c, alpha=1.0, label=label)
    for (x_prev, y_prev), (x_next, y_next) in zip(container, container[1:]) :
        if not np.fabs(x_next - x_prev) > 1.5 :
            ax.plot([x_prev, x_next], [y_prev, y_next], "-", c=c, lw=2, alpha=0.7)
            
            
def create_training_curves_plot(loss_record, ref_loss_record, maxQ_record, show=False, close=False, save="") :
    '''
    Create a plt.Figure instance visualising the training curves. Allows for plot to be shown, saved and/or
    closed using plt interface. Returns the plot figure and axis objects so they can continue to be
    manipulated, but note that the figure will have been cleared and closed if close=True.
    Inputs:
      > loss_record, dict with keys "Q1" and "Q2", each a sequence of (x, y) pairs (x presumably the epoch)
        mean loss per batch for each model, drawn on the top axis with a log y-scale
      > ref_loss_record, dict with keys "Q1" and "Q2", each a sequence of (x, y) pairs
        mean reference-only loss per batch for each model, drawn on the middle axis with a log y-scale
      > maxQ_record, dict with keys "Q1" and "Q2", each a sequence of (x, y) pairs
        max |q(s,a)| for each model, drawn on the bottom axis
      > show, bool, default=False
        if True then call plt.show() (with a non-interactive backend such as agg this may display nothing)
      > close, bool, default=False
        if True then clear the figure and call plt.close(fig)
      > save, str, default=""
        if string provided then call fig.savefig(save, ...), creating any required subdirectories if needed
    Returns:
      > plt.Figure instance
      > plt.Axes instance (axis corresponding to loss curves)
      > plt.Axes instance (axis corresponding to ref_loss curves)
      > plt.Axes instance (axis corresponding to maxQ curves)
    '''
            
    fig = plt.figure(figsize=(30,15))
    
    #  Top panel: combined loss
    ax1 = fig.add_subplot(3, 1, 1)
    ax1.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
    ax1.set_title(r"Mean loss per batch [$(1-\lambda)\cdot$batch + $\lambda\cdot$ref]", fontsize=30)
    ax1.xaxis.set_ticklabels([])
    draw_training_curve(ax1, loss_record["Q1"], "o", "r", "$q_1$")
    draw_training_curve(ax1, loss_record["Q2"], "x", "b", "$q_2$")
    ax1.set_yscale("log")
    ax1.legend(loc=(0,1.02), fontsize=30, ncol=3, title_fontsize=30)
    
    #  Middle panel: reference-only loss
    ax2 = fig.add_subplot(3, 1, 2)
    ax2.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
    ax2.set_title(r"Mean loss per batch [ref only]", fontsize=30)
    ax2.xaxis.set_ticklabels([])
    draw_training_curve(ax2, ref_loss_record["Q1"], "o", "r", "$q_1$")
    draw_training_curve(ax2, ref_loss_record["Q2"], "x", "b", "$q_2$")
    ax2.set_yscale("log")
    
    #  Bottom panel: largest |q| values
    ax3 = fig.add_subplot(3, 1, 3)
    ax3.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
    ax3.set_title(r"Max $|q(s,a)|$ over all batches", fontsize=30)
    ax3.set_xlabel(r"Epoch", labelpad=15, fontsize=30)
    draw_training_curve(ax3, maxQ_record["Q1"], "o", "r", "$q_1$")
    draw_training_curve(ax3, maxQ_record["Q2"], "x", "b", "$q_2$")
    ax3.axhline(0, ls="--", lw=2, c="gray")
    
    fig.subplots_adjust(hspace=0.2)
    
    #  Save / show / close, acting on this figure explicitly rather than pyplot's "current" figure
    if len(save) > 0 :
        generate_directory_for_file_path(save)
        fig.savefig(save, bbox_inches="tight")
    if show :
        #  pyplot.show takes no positional figure argument, so do not pass fig here
        plt.show()
    if close :
        fig.clf()
        plt.close(fig)
        
    return fig, ax1, ax2, ax3
    

In [None]:
###
###  Identify priority state/action pairs
###

#  Priority (anchoring) pairs are the state/action pairs that move the agent directly into the terminal state
#  r_end. Their return is known exactly: perform_action gives zero reward when the terminal state is reached.
priority_states, priority_actions, priority_returns = [], [], []
for action in action_list :
    r_initial = r_end - action
    #  Skip the "stay" action (its initial state is the terminal state itself) and any initial state that lies
    #  off the board, which can happen when r_end sits on the board edge (e.g. with zero padding)
    if is_terminal(r_initial) or is_out_of_bounds(r_initial) : continue
    priority_states .append(r_initial)
    priority_actions.append(action)
    priority_returns.append(0.)
    
#  reshape(-1, 2) keeps the (num_pairs, 2) shape even if no pairs were found
priority_states  = np.array(priority_states ).reshape(-1, 2)
priority_actions = np.array(priority_actions).reshape(-1, 2)
priority_returns = np.array(priority_returns)

print(f"Found {len(priority_returns)} priority state-action pairs with returns:  {'  '.join([f'{x:.2f}' for x in priority_returns])}")


In [None]:
def create_config(config_fname, q1_model, q2_model, to_stdout=True) :
    '''
    Print environment, training and model configurations to file config_fname. Also print environment and
    training configurations to sys.stdout if requested, but do not print model summaries as they are verbose.
    The configuration values are read from module-level variables defined in other cells (board geometry,
    storm and reward settings, batch_size, learning_rate, bootstrap_method, ...), so those cells must have
    been run first.
    Inputs:
      > config_fname, str
        name of config file to create
      > q1_model, keras Model
        first q-value model
      > q2_model, keras Model
        second q-value model
      > to_stdout, bool, default=True
        if True then repeat environment and training configurations to sys.stdout
    Returns:
      > None
    '''
    # Create message as list of strings
    config_message = []
    config_message.append(f"="*114 + "\n")
    config_message.append(f"Environment config:\n")
    config_message.append(f"> horizontal_size: {horizontal_size}\n")
    config_message.append(f"> vertical_size: {vertical_size}\n")
    config_message.append(f"> horizontal_pad: {horizontal_pad}\n")
    config_message.append(f"> vertical_pad: {vertical_pad}\n")
    config_message.append(f"> horizontal_max: {horizontal_max}\n")
    config_message.append(f"> vertical_max: {vertical_max}\n")
    # Storm settings define the random environments, so record them too
    config_message.append(f"> num_storms: {num_storms}\n")
    config_message.append(f"> storm_radius: {storm_radius:.6f}\n")
    config_message.append(f"> reward_per_turn: {reward_per_turn:.6f}\n")
    config_message.append(f"> lambda_r: {lambda_r:.6f}\n")
    config_message.append(f"> lambda_b: {lambda_b:.6f}\n")
    config_message.append(f"> lambda_w: {lambda_w:.6f}\n")
    config_message.append(f"> gamma: {gamma:.6f}\n")
    config_message.append(f"="*114 + "\n")
    config_message.append(f"Training config:\n")
    config_message.append(f"> Using bootstrap method: {bootstrap_method}\n")
    config_message.append(f"> Training for {num_epochs} epochs (negative means no limit)\n")
    config_message.append(f"> Using epochs of length {num_train}\n")
    config_message.append(f"> Updating gradient every batch of size {batch_size}\n")
    config_message.append(f"> Using optimizer_q1 {optimizer_q1} with learning rate {learning_rate:.6}\n")
    config_message.append(f"> Using optimizer_q2 {optimizer_q2} with learning rate {learning_rate:.6}\n")
    config_message.append(f"> Plotting policy every {plot_policy_after_epochs} epochs\n")
    config_message.append(f"> Plotting monitors every {plot_monitors_after_epochs} epochs\n")
    config_message.append(f"> Swapping q1 and q2 every {switch_after_epochs} epochs\n")
    config_message.append(f"> Cloning q2 from q1 every {clone_after_epochs} epochs\n")
    config_message.append(f"> Assigning a weight of {priority_weight} to anchoring state/action pairs\n")
    config_message.append(f"> num_emp_steps: {num_emp_steps}\n")
    config_message.append(f"="*114 + "\n")
    # Make sure directory exists for file
    generate_directory_for_file_path(config_fname, print_msg_on_dir_creation=True)
    # Open file and print messages, also to stdout if configured
    # - also print q-model summaries, only to file
    # - print_fn accepts (and ignores) extra keyword arguments in case the Keras version passes any
    with open(config_fname, "w") as config_file :
        for line in config_message :
            config_file.write(line)
            if not to_stdout : continue
            sys.stdout.write(line)
        config_file.write("\nModel configs:\n\n")
        q1_model.summary(print_fn=lambda x, **kwargs: config_file.write(x + '\n'))
        config_file.write("\n")
        q2_model.summary(print_fn=lambda x, **kwargs: config_file.write(x + '\n'))



In [None]:

## Configure

num_epochs                 = -1            # negative means keep training until interrupted
batch_size                 = 25
learning_rate              = 1e-4
plot_policy_after_epochs   = 10
plot_monitors_after_epochs = 2
switch_after_epochs        = -1
clone_after_epochs         = 2
priority_weight            = 0.3
bootstrap_method           = "clone"       # ["clone", "self", "other"]
num_emp_steps              = 1
random_seed                = None          # set to an int to make shuffling / weight initialisation reproducible


## Validate the configuration before doing any expensive set-up (model creation)

if bootstrap_method not in ["clone", "self", "other"] :
    raise NotImplementedError(f"Bootstrap method {bootstrap_method} not implemented")


## Set up

if random_seed is not None :
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)

loss_fcn     = tf.keras.losses.MeanSquaredError()
optimizer_q1 = tf.keras.optimizers.SGD(learning_rate=learning_rate)
optimizer_q2 = tf.keras.optimizers.SGD(learning_rate=learning_rate)
q1_model     = create_action_value_model(name="action_value_model_1")
#  With bootstrap_method == "self" a single model bootstraps from itself, so q2 is the same object as q1
q2_model     = q1_model if bootstrap_method == "self" else create_action_value_model(name="action_value_model_2")

#  Enumerate every (non-terminal state, action) pair once, in random order; an epoch iterates over all of them
state_action_pairs = []
for x in range(horizontal_max) : 
    for y in range(vertical_max) : 
        if is_terminal(np.array([x,y])) : continue
        for a in action_list :
            state_action_pairs.append([x,y,a[0],a[1]])
np.random.shuffle(state_action_pairs)
state_action_pairs = np.array(state_action_pairs)
num_train = len(state_action_pairs)


## Print config to file and screen (model summaries only to file because they are verbose)

create_config("figures/Helicopter_NB2/config.txt", q1_model, q2_model, to_stdout=True)


## Set up monitors for training start

model_key_q1, model_key_q2 = "Q1", "Q2"  # used to keep track of which model is being trained each epoch
loss_record, ref_loss_record, maxQ_record = {"Q1":[], "Q2":[]}, {"Q1":[], "Q2":[]}, {"Q1":[], "Q2":[]}

epoch_idx = 0

In [None]:

#  Evaluate the untrained q1 model on the first 20 (shuffled) state/action pairs, each with a fresh weather map
initial_q_samples = q1_model.predict([
    np.array([create_weather_map() for _ in range(20)]),
    state_action_pairs[:20, 0:2],
    state_action_pairs[:20, 2:4],
]).flatten()
print(f"A sample of initial q-values are {initial_q_samples}")


2022-08-02 16:15:33.653056: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


array([[0.47059593],
       [0.5246563 ],
       [0.5206993 ],
       [0.47516003],
       [0.5044229 ],
       [0.51498663],
       [0.49184936],
       [0.5258844 ],
       [0.57368016],
       [0.4764757 ]], dtype=float32)

In [None]:
#  Disabled memory-profiling cell. To use it, tracemalloc must be imported and tracing started, and a
#  `display_top` helper must be defined; none of these appear in the visible notebook.
'''
snapshot_epoch201  = tracemalloc.take_snapshot()
display_top(snapshot_epoch201, limit=20)
'''

In [None]:
#  Disabled memory-profiling cell. To use it, tracemalloc must be imported and tracing started, and a
#  `display_top` helper must be defined; none of these appear in the visible notebook.
'''
snapshot_epoch154  = tracemalloc.take_snapshot()
display_top(snapshot_epoch154, limit=20)
'''

In [None]:
#  Disabled memory-profiling cell. To use it, tracemalloc must be imported and tracing started, and a
#  `display_top` helper must be defined; none of these appear in the visible notebook.
'''
snapshot_epoch93 = tracemalloc.take_snapshot()
display_top(snapshot_epoch93, limit=20)
'''

In [None]:
#  Disabled cell: plots the greedy policy for num_validation_games random weather maps.
#  NOTE(review): the save path depends only on epoch_idx (not game_idx), so every iteration overwrites
#  the same PDF; include game_idx in the file name if this cell is re-enabled.
'''
num_validation_games = 20

for game_idx in range(num_validation_games) :
    create_greedy_policy_plot(create_weather_map(), q1_model, q2_model, epoch_idx=epoch_idx, verbose=True,
                              show=True, close=True, save=f"figures/Helicopter_NB2/greedy_policy_val{epoch_idx}.pdf")'''

In [None]:
#  Disabled helper: compares two tracemalloc snapshots and prints the largest memory changes among the top
#  `limit` statistics of each. Needs `import tracemalloc` and the snapshot_epoch154 / snapshot_epoch201
#  variables created by the disabled snapshot cells.
#  NOTE(review): stats are keyed by filename only, so several line-level stats from the same file overwrite
#  each other instead of accumulating, and files that only appear in snapshot1 (decreases) are never listed.
'''
def compare_snapshots(snapshot1, snapshot2, key_type='lineno', limit=10):
    print("Filtering snapshot 1")
    snapshot1 = snapshot1.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    
    print("Filtering snapshot 2")
    snapshot2 = snapshot2.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    
    print("Getting top stats")
    top_stats1 = snapshot1.statistics(key_type)
    top_stats2 = snapshot2.statistics(key_type)
    total_mem1 = sum(stat.size for stat in top_stats1)
    total_mem2 = sum(stat.size for stat in top_stats2)
    print(f"total_mem1 = {total_mem1/1024/1024:.3f} MB")
    print(f"total_mem2 = {total_mem2/1024/1024:.3f} MB")
    print(f"change in mem = {total_mem2 - total_mem1}")
    
    print("Getting memory 1")
    memory_dict1 = {}
    for index, stat in enumerate(top_stats1[:limit], 1):
        key, size = stat.traceback[0].filename, stat.size / 1024 / 1024
        memory_dict1[key] = size
        
    print("Getting memory 2")
    memory_dict2 = {}
    for index, stat in enumerate(top_stats2[:limit], 1):
        key, size = stat.traceback[0].filename, stat.size / 1024 / 1024
        memory_dict2[key] = size
        
    print("Getting memory change")
    memory_increase = []
    for key, mem1 in memory_dict1.items() :
        if key not in memory_dict2 : continue
        memory_increase.append((key, memory_dict2[key] - mem1))
    for key, mem2 in memory_dict2.items() :
        if key in memory_dict1 : continue
        memory_increase.append((key, mem2))
        
    print("Sorting list")
    memory_increase.sort(key = lambda x : x[1], reverse=True)
    
    print("Printing results")
    for key, delta_mem in memory_increase[:limit] :
        print(f"{key}    {delta_mem:.3f} MB")
        
        
compare_snapshots(snapshot_epoch154, snapshot_epoch201)'''

In [None]:
#  Disabled cell: would diff two tracemalloc snapshots; snapshot_start, snapshot_end and display_top are not
#  created anywhere in the visible notebook.
'''
top_stats_compare = snapshot_end.compare_to(snapshot_start, 'lineno')

display_top(top_stats_compare)
    '''

In [None]:
#  Disabled debugging cell: prints sys.getsizeof for many notebook / training-loop variables.
#  NOTE(review): sys.getsizeof reports only the shallow size of each object (e.g. not a model's weights or a
#  list's elements), and some names (e.g. `states`) may not be defined in this notebook.
'''
print("plt", sys.getsizeof(plt))

print("loss_fcn", sys.getsizeof(loss_fcn))
print("optimizer_q1", sys.getsizeof(optimizer_q1))
print("optimizer_q2", sys.getsizeof(optimizer_q2))
print("q1_model", sys.getsizeof(q1_model))
print("q2_model", sys.getsizeof(q2_model))

print("states", sys.getsizeof(states))
print("loss_record", sys.getsizeof(loss_record))
print("ref_loss_record", sys.getsizeof(ref_loss_record))

print("maxQ_record", sys.getsizeof(maxQ_record))
print("epoch_losses", sys.getsizeof(epoch_losses))
print("ref_losses", sys.getsizeof(ref_losses))

print("max_abs_q_values", sys.getsizeof(max_abs_q_values))
print("r_agents", sys.getsizeof(r_agents))
print("actions", sys.getsizeof(actions))

print("weather_maps", sys.getsizeof(weather_maps))
print("rewards", sys.getsizeof(rewards))
print("r_agents_p", sys.getsizeof(r_agents_p))
print("actions_p", sys.getsizeof(actions_p))
print("weather_map", sys.getsizeof(weather_map))
print("r_agent_p", sys.getsizeof(r_agent_p))

print("action_p", sys.getsizeof(action_p))
print("bootstrap_values", sys.getsizeof(bootstrap_values))
print("obs_returns", sys.getsizeof(obs_returns))
print("sample_weights", sys.getsizeof(sample_weights))
print("priority_weights", sys.getsizeof(priority_weights))

print("priority_weather", sys.getsizeof(priority_weather))
print("ref_returns", sys.getsizeof(ref_returns))
print("ref_loss", sys.getsizeof(ref_loss))
print("tape", sys.getsizeof(tape))

print("pred_returns", sys.getsizeof(pred_returns))
print("loss_value", sys.getsizeof(loss_value))
print("grads", sys.getsizeof(grads))

print("np_returns", sys.getsizeof(np_returns))
print("epoch_mean_loss", sys.getsizeof(epoch_mean_loss))
print("epoch_mean_ref_loss", sys.getsizeof(epoch_mean_ref_loss))
print("epoch_maxQ", sys.getsizeof(epoch_maxQ))
'''

In [None]:
#  Disabled debugging cell: lists the variable names in the namespace (at notebook top level locals() is globals()).
'''
for key, desc in locals().items() :
    print(key)'''

In [None]:
#  Disabled debugging cell: lists the names defined in the notebook's global namespace.
'''for key, desc in globals().items() :
    print(key)'''

In [22]:
 state_action_pairs[:20,0:2]

array([[ 9,  4],
       [ 7,  5],
       [ 3,  8],
       [11,  0],
       [ 4,  6],
       [11,  5],
       [ 5,  4],
       [10,  1],
       [ 2,  8],
       [ 3,  3],
       [11,  4],
       [11,  5],
       [11,  5],
       [ 0,  7],
       [10,  6],
       [ 7,  8],
       [ 7,  7],
       [ 9,  8],
       [ 9,  1],
       [ 7,  7]])

In [23]:
#  Inspect the matching first 20 (dx, dy) actions
state_action_pairs[:20, 2:4]

array([[ 0,  0],
       [ 1,  1],
       [-1, -1],
       [-1,  1],
       [-1, -1],
       [ 1,  0],
       [ 1,  0],
       [-1, -1],
       [ 0,  1],
       [ 1,  1],
       [ 0, -1],
       [ 1,  1],
       [ 0, -1],
       [ 1, -1],
       [ 1, -1],
       [ 1,  0],
       [ 1,  0],
       [ 1,  1],
       [ 0,  0],
       [ 0,  1]])

In [None]:
## Start or continue training
#
#  Runs, or resumes, the main training loop from the current value of `epoch_idx`. Each epoch:
#    1. optionally saves a plot of the current greedy policy and of the training curves
#    2. optionally swaps q1/q2 (bootstrap_method == "other") or copies q1 into q2 (bootstrap_method == "clone")
#    3. performs one gradient update of q1 per batch of `state_action_pairs`, regressing onto
#       num_emp_steps-step returns bootstrapped from q2, mixed with the weighted priority samples
#    4. records the epoch mean loss, mean reference loss and max |Q| under model_key_q1
#  Setting num_epochs < 0 makes the loop run indefinitely (interrupt the kernel to stop it).

start_time = time.time()
while epoch_idx < num_epochs or num_epochs < 0 :
    
    # Determine whether to plot current greedy policy
    if plot_policy_after_epochs > 0 and epoch_idx % plot_policy_after_epochs == 0 :
        create_greedy_policy_plot(create_weather_map(), q1_model, q2_model, epoch_idx=epoch_idx, verbose=True,
                                  show=False, close=True, save=f"figures/Helicopter_NB2/greedy_policy_epoch{epoch_idx}.pdf")
    
    # Determine whether to plot training curves
    if plot_monitors_after_epochs > 0 and epoch_idx > 0 and epoch_idx % plot_monitors_after_epochs == 0 :
        create_training_curves_plot(loss_record, ref_loss_record, maxQ_record, show=False, close=True,
                                    save="figures/Helicopter_NB2/training_curves.pdf")
        
    # Determine whether to switch q1 and q2 (optimizers are swapped in tandem so each model keeps its own optimizer)
    if bootstrap_method == "other" and switch_after_epochs > 0 and epoch_idx > 0 and epoch_idx % switch_after_epochs == 0 :
        model_key_q1, model_key_q2 = model_key_q2, model_key_q1
        q1_model, q2_model         = q2_model, q1_model
        optimizer_q1, optimizer_q2 = optimizer_q2, optimizer_q1
        
    # Determine whether to copy q1 to q2 (no epoch_idx > 0 condition, so this also happens at epoch 0)
    if bootstrap_method == "clone" and clone_after_epochs > 0 and epoch_idx % clone_after_epochs == 0 :
        q2_model.set_weights(q1_model.get_weights()) 
    
    sys.stdout.write(f"\rEpoch {epoch_idx+1} / {num_epochs}  [t={time.time()-start_time:.2f}s]".ljust(110))
    
    # Perform one gradient update per batch
    epoch_losses, ref_losses, max_abs_q_values = [], [], []
    for batch_idx in range(math.ceil(num_train/batch_size)) :
        
        # Resolve sample indices to be used for this batch update
        batch_idx_low, batch_idx_high = batch_idx*batch_size, min((batch_idx+1)*batch_size, num_train)
        actual_batch_size = batch_idx_high - batch_idx_low
        if actual_batch_size == 0 : continue
        
        # Update the current epoch message to keep track of batch progress 
        sys.stdout.write(f"\rEpoch {epoch_idx+1} / {num_epochs} batch indices ({batch_idx_low}, {batch_idx_high}) / {num_train}  [t={time.time()-start_time:.2f}s]".ljust(110))
        
        # Get batch of states and their pre-generated exploration actions
        r_agents = state_action_pairs[batch_idx_low:batch_idx_high,0:2]
        actions  = state_action_pairs[batch_idx_low:batch_idx_high,2:4]
        
        # Generate some unique weather maps to avoid over-training if we re-use maps
        weather_maps = np.array([create_weather_map() for _ in range(actual_batch_size)])
        
        # For each state/action pair, roll out up to num_emp_steps steps. The first step applies the sampled
        # exploration action, and every later step applies the greedy (q1) action from the state just reached.
        # The discounted sum of these rewards is the empirical part of the return, and the final state plus
        # its greedy action are where the q2 bootstrap is evaluated. Roll-outs stop early at terminal states.
        rewards, r_agents_p, actions_p, agent_p_is_terminal = [], [], [], []
        for r_agent, action, weather_map in zip(r_agents, actions, weather_maps) :
            r_agent_p, action_p, step_y, emp_return = r_agent, action, 1., 0.
            for step_idx in range(num_emp_steps) :
                if is_terminal(r_agent_p) : break
                # Step from the *current* state with the *current* action (not the original pair)
                reward, r_agent_p = perform_action(weather_map, r_agent_p, action_p)
                action_p, _       = get_greedy_action(weather_map, r_agent_p, q1_model)
                emp_return       += step_y * reward
                step_y           *= gamma
            rewards   .append(emp_return)
            r_agents_p.append(r_agent_p)
            actions_p .append(action_p)
            agent_p_is_terminal.append(is_terminal(r_agent_p))
        rewards, r_agents_p, actions_p = np.array(rewards), np.array(r_agents_p), np.array(actions_p)
        
        # Calculate obs_returns using the observed n-step rewards plus gamma**n * bootstrap values, with the
        # bootstrap set to zero for terminal states. The model is called directly rather than via predict(),
        # which has a large per-call overhead inside this loop and may print its own progress bar
        bootstrap_values = q2_model([weather_maps, r_agents_p, actions_p], training=False).numpy().flatten()
        bootstrap_values = np.where(agent_p_is_terminal, 0., bootstrap_values)
        obs_returns      = rewards + (gamma**num_emp_steps) * bootstrap_values
        
        # Calculate weights to apply to regular batch and priority samples
        sample_weights   = (1.-priority_weight) * np.full(shape=(len(r_agents),), fill_value=1./len(r_agents))
        priority_weights = priority_weight * np.full(shape=(len(priority_states),), fill_value=1./len(priority_states))
        
        # Calculate weather for priority samples (regenerated for every batch)
        priority_weather = np.array([create_weather_map() for _ in range(len(priority_states))])
        
        # Concatenate regular batch and priority samples
        if priority_weight > 0 and priority_weight <= 1 :
            sample_weights   = np.concatenate([sample_weights , priority_weights]).flatten()
            r_agents         = np.concatenate([r_agents       , priority_states ])
            actions          = np.concatenate([actions        , priority_actions])
            weather_maps     = np.concatenate([weather_maps   , priority_weather])
            obs_returns      = np.concatenate([obs_returns    , priority_returns]).flatten()
        
        # When using sample weights, we have to be careful with the object shapes. The loss function will
        # expect y_pred of shape (N,1) and y_true of shape (N,1), when the output shape is (1,), and sample
        # weights of shape (N,). When not using sample weights we can get away with being lazy on the shapes
        # of y_pred and y_true, but if we do this here then it will not correctly apply the correct sample
        # weight to the correct sample. Furthermore, MSE with sample weights will calculate mean(sw*res**2)
        # rather than normalising according to sum(sw*res**2)/sum(sw). This means we must also multiply sw 
        # by a factor of len(sw)/sum(sw) if we are to recover the MSE value that we expect. Note that we have
        # sum(sw)=1 by our construction above, but I still write it explicitly so the general idea is clear.
        sample_weights = sample_weights * len(sample_weights) / sample_weights.sum()
        obs_returns    = obs_returns.reshape((len(obs_returns),1))
        
        # ref_loss is the loss over only the priority state/action pairs, before gradient updates for
        # consistency with regular loss record. Since regular loss mixes (i) how close the priority samples
        # are to their correct values and (ii) how close other points are to their bootstrap-biased values,
        # the ref_loss removes the bootstrap samples so we can see whether these points are diverging
        ref_returns = q1_model([priority_weather, priority_states, priority_actions], training=False)
        ref_loss    = loss_fcn(priority_returns.reshape((len(ref_returns),1)), ref_returns.numpy())
        ref_losses.append(float(ref_loss))
        
        # Apply gradient update. Only the forward pass and loss are recorded on the tape; computing and
        # applying the gradients outside the `with` block avoids needlessly recording those ops too
        with tf.GradientTape() as tape:
            pred_returns = q1_model([weather_maps, r_agents, actions], training=True)
            loss_value   = loss_fcn(obs_returns, pred_returns, sample_weight=sample_weights)
        grads = tape.gradient(loss_value, q1_model.trainable_weights)
        optimizer_q1.apply_gradients(zip(grads, q1_model.trainable_weights))
        
        # Record the batch loss and the largest |Q| predicted (before this update) for the epoch summary
        epoch_losses.append(loss_value.numpy())
        np_returns = np.fabs(pred_returns.numpy().flatten())
        max_abs_q_values.append(np_returns.max())
                                            
    # Print epoch summary and end stdout line to keep this on screen. The padding is applied before the
    # newline so that it overwrites any longer batch-progress message still visible on the current line
    epoch_mean_loss, epoch_mean_ref_loss, epoch_maxQ = np.mean(epoch_losses), np.mean(ref_losses), np.max(max_abs_q_values)
    epoch_summary = f"\rEpoch {epoch_idx+1} / {num_epochs}  [t={time.time()-start_time:.2f}s]  <loss = {epoch_mean_loss:.5f}, ref_loss = {epoch_mean_ref_loss:.5f}, max_Q = {epoch_maxQ:.1f}>"
    sys.stdout.write(epoch_summary.ljust(110) + "\n")
    loss_record    [model_key_q1].append((epoch_idx, epoch_mean_loss    ))
    ref_loss_record[model_key_q1].append((epoch_idx, epoch_mean_ref_loss))
    maxQ_record    [model_key_q1].append((epoch_idx, epoch_maxQ         ))
                      
    # Manually iterate epoch index and make sure stdout not lagging
    epoch_idx += 1
    sys.stdout.flush()
        