#  RL to avoid static 2D weather using embedded values (no terminal state)

In [1]:
###
###  Import packages
###

import math, os, pickle, sys, time

import numpy as np

from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers     import Concatenate, Dense, Dropout, Input, Normalization, Rescaling
from tensorflow.keras.models     import Model
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.losses     import MeanSquaredError

print("TensorFlow has found devices:")
for device in tf.config.list_physical_devices() :
    print(f"-  {device}")

#  Set matplotlib interactive backend via linemagic
#  -  interative backend allows plots to be shown in notebook
#  -  keep an eye on memory leak caused by creation of GUI objects which are not deleted
#  -  choose non-interative backend such as 'agg' to avoid memory leak, but cannot plot in notebook

%matplotlib agg




TensorFlow has found devices:
-  PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
-  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [2]:
###
###  Global config settings
###

global_config = dict(
    autoencoder_dir       = "models/2D_static_weather_20220808_emb20/autoencoder",
    encoder_dir           = "models/2D_static_weather_20220808_emb20/encoder",
    square_size           = 40                         ,
    max_weather_intensity = 18                         ,
    embedding_length      = 20                         ,
)

In [3]:
###
###  Class to allow configuration and interation with the environment
###

class Environment :
    
    action_dx = np.array([[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]])
    
    def __init__(self, autoencoder_dir=None, encoder_dir=None, square_size=10, max_weather_intensity=20, 
                 embedding_length=10, actions_dx=None, r_per_turn=0., r_per_dx=-0.2, r_per_b=-1., 
                 r_per_w=-1., gamma=.95) :
        if actions_dx is None :
            actions_dx = Environment.action_dx.copy()
        self.autoencoder_dir       = None
        self.encoder_dir           = None
        self.embedding_length      = None
        self.encoder_model         = None
        self.square_size           = None
        self.actions_dx            = None
        self.max_weather_intensity = None
        self.r_per_turn            = None
        self.r_per_dx              = None
        self.r_per_b               = None
        self.r_per_w               = None
        self.gamma                 = None
        self.configure_encoder_model(encoder_dir, embedding_length)
        self.configure_autoencoder_model(autoencoder_dir, embedding_length)
        self.configure_states_and_actions(square_size, actions_dx)
        self.configure_returns(max_weather_intensity, r_per_turn, r_per_dx, r_per_b, r_per_w, gamma)
        
    def __str__(self) :
        return self.get_summary()
    
    def configure_autoencoder_model(self, autoencoder_dir=None, embedding_length=None) :
        if autoencoder_dir  is None : autoencoder_dir  = self.autoencoder_dir
        if embedding_length is None : embedding_length = self.embedding_length
        self.autoencoder_dir  = autoencoder_dir
        self.embedding_length = embedding_length
        if type(autoencoder_dir) is str and len(autoencoder_dir) > 0 :
            self.load_autoencoder_model(autoencoder_dir)
    
    def configure_encoder_model(self, encoder_dir=None, embedding_length=None) :
        if encoder_dir      is None : encoder_dir      = self.encoder_dir
        if embedding_length is None : embedding_length = self.embedding_length
        self.encoder_dir      = encoder_dir
        self.embedding_length = embedding_length
        if type(encoder_dir) is str and len(encoder_dir) > 0 :
            self.load_encoder_model(encoder_dir)
        
    def configure_states_and_actions(self, square_size=None, actions_dx=None) :
        if square_size is None : square_size = self.square_size
        if square_size is None : 
            raise RuntimeError(f"no square_size provided and none already set")
        if actions_dx  is None : actions_dx = self.actions_dx
        if actions_dx  is None : actions_dx = Environment.actions_dx
        if type(actions_dx) != np.ndarray :
            raise RuntimeError(f"actions_dx must be a numpy array but found type {type(actions_dx)}")
        if len(actions_dx.shape) != 2 :
            raise RuntimeError(f"actions_dx must be a 2D array but found shape {actions_dx.shape}")
        if actions_dx.shape[1] != 2 :
            raise RuntimeError(f"axis 1 of actions_dx must have length 2 found shape {actions_dx.shape}")
        self.square_size = square_size
        self.actions_dx  = actions_dx.copy()
        self.num_states  = square_size * square_size
        self.num_actions = len(actions_dx)
        self.x_min       = 0
        self.x_max       = square_size - 1
        self.x_range     = self.x_max - self.x_min
        self.actions_dx  = actions_dx
        self.num_actions = len(actions_dx)
        self.a_min       = actions_dx.min()
        self.a_max       = actions_dx.max()
        self.a_range     = self.a_max - self.a_min
                
    def configure_returns(self, max_weather_intensity=None, r_per_turn=None, r_per_dx=None, r_per_b=None, 
                          r_per_w=None, gamma=None) :
        if max_weather_intensity is None : max_weather_intensity = self.max_weather_intensity
        if r_per_turn is None : r_per_turn = self.r_per_turn
        if r_per_dx   is None : r_per_dx   = self.r_per_dx
        if r_per_b    is None : r_per_b    = self.r_per_b
        if r_per_w    is None : r_per_w    = self.r_per_w
        if gamma      is None : gamma      = self.gamma
        self.max_weather_intensity = max_weather_intensity
        self.r_per_turn = r_per_turn
        self.r_per_dx   = r_per_dx
        self.r_per_b    = r_per_b
        self.r_per_w    = r_per_w
        self.gamma      = gamma
        
    def create_decoder_model_from_autoencoder(self, autoencoder_model=None) :
        if autoencoder_model is None : 
            autoencoder_model = self.autoencoder_model
        if autoencoder_model is None : 
            raise RuntimeError(f"no autoencoder_model provided and none already set")
        input_layer = Input(self.embedding_length)
        next_layer  = input_layer
        is_decoder_layer = False
        for layer in autoencoder_model.layers[1:] :
            if layer.name == "encoder_output" :
                is_decoder_layer = True
                continue
            if not is_decoder_layer : 
                continue
            next_layer = layer(next_layer)
        self.decoder_model = Model(input_layer, next_layer, name="decoder_model")
        self.decoder_model.compile(optimizer="sgd", loss="mse")
    
    def enforce_action_dx_is_valid(self, a_dx) :
        if self.is_action_dx_out_of_bounds(a_dx) :
            raise RuntimeError(f"action dx ({a_dx}) not in allowed list ({self.actions_dx})")
    
    def enforce_action_idx_is_valid(self, a_idx) :
        if self.is_action_idx_out_of_bounds(a_idx) :
            raise RuntimeError(f"action index ({a_idx}) is out of bounds (0-{self.num_actions-1})")
        
    def enforce_state_is_valid(self, x) :
        if self.is_x_out_of_bounds(x) :
            raise RuntimeError(f"position ({x}) is out of bounds")
        return True
    
    def generate_map(self, storm_density=0.003, calm_weather_intensity=2., storm_intensity_low=8., 
                     storm_intensity_high=14., storm_decay_rate_low=.7, storm_decay_rate_high=1.4) :
        ##    Generate storm locations and intensities
        square_size     = self.square_size
        storm_pad       = math.floor(storm_intensity_high / storm_decay_rate_low / 2.)
        mean_num_storms = storm_density * square_size * square_size
        num_storms      = np.random.poisson(lam=mean_num_storms)
        storms_list     = []
        for storm_idx in range(num_storms) :
            storm_x, storm_y = np.random.uniform(low=-storm_pad, high=square_size+storm_pad-1, size=(2,))
            storm_i  = np.random.uniform(low=storm_intensity_low, high=storm_intensity_high)
            storm_dr = np.random.uniform(low=storm_decay_rate_low, high=storm_decay_rate_high)
            storms_list.append([storm_i, storm_dr, storm_x, storm_y])
        ##    Generate weather map
        weather_map = np.zeros(shape=(square_size, square_size, 1))
        for x in range(square_size) :
            for y in range(square_size) :
                weather_intensity = calm_weather_intensity
                for storm_i, storm_dr, storm_x, storm_y in storms_list :
                    r = np.sqrt((x-storm_x)**2 + (y-storm_y)**2)
                    weather_intensity += max(0, storm_i - storm_dr*r)
                weather_map[x, y, 0] = min(np.random.poisson(lam=weather_intensity), self.max_weather_intensity)
        ##    Return new storm and weather containers
        return weather_map
    
    def generate_weather(self, num_data=1, verbose=False) :
        num_data = int(num_data)
        data_np = []
        for data_idx in range(num_data) :
            if verbose :
                sys.stdout.write(f"\rGenerating weather {1+data_idx} / {num_data}".ljust(100))
            data_np.append(self.generate_map())
        if verbose :
            sys.stdout.write(f"\n")
        data_np      = np.array(data_np)
        data         = tf.constant(data_np)
        embedded_data, embedded_data_np = None, None
        if not (self.encoder_model is None) :
            if verbose :
                sys.stdout.write(f"\rEmbedding weather")
            embedded_data = self.encoder_model(data, training=False)
            if verbose :
                sys.stdout.write(f"\n")
        self.data_np          = data_np
        self.data             = data
        self.embedded_data    = embedded_data
        self.embedded_data_np = embedded_data.numpy()
        return data_np, embedded_data
    
    def get_action_dx_from_index(self, a_idx) :
        return self.actions_dx[a_idx]
    
    def get_action_index_from_dx(self, a_dx) :
        return list(self.actions_dx).index(a_dx)
    
    def get_embedding_means_scales(self, num_data=-1, verbose=False) :
        if num_data > 0 :
            self.generate_weather(num_data=num_data, verbose=verbose)
        if not hasattr(self, "embedded_data_np") or (self.embedded_data_np is None) :
            raise RuntimeError("must specify num_data > 0 when no previous weather generated")
        means, scales = [], []
        for emb_idx in range(self.embedding_length) :
            v = self.embedded_data_np[:,emb_idx]
            means .append(np.mean(v))
            scales.append(np.std (v))
        self.means  = np.array(means)
        self.scales = np.array(scales)
        return self.means, self.scales
    
    def get_summary(self, write_to=None) :
        str_summary  =  "Environment config:\n"
        str_summary += f"    autoencoder_dir       | {self.autoencoder_dir}\n"
        str_summary += f"    encoder_dir           | {self.encoder_dir}\n"
        str_summary += f"    square_size           | {self.square_size}\n"
        str_summary += f"    max_weather_intensity | {self.max_weather_intensity}\n"
        str_summary += f"    embedding_length      | {self.embedding_length}\n"
        str_summary += f"    num_states            | {self.num_states}\n"
        str_summary += f"    x_min                 | {self.x_min}\n"
        str_summary += f"    x_max                 | {self.x_max}\n"
        str_summary += f"    num_actions           | {self.num_actions}\n"
        str_summary += f"    actions_dx            | {' '.join([f'{a_dx}' for a_dx in self.actions_dx])}\n"
        str_summary += f"    r_per_dx              | {self.r_per_dx}\n"
        str_summary += f"    r_per_b               | {self.r_per_b}\n"
        str_summary += f"    r_per_w               | {self.r_per_w}\n"
        str_summary += f"    gamma                 | {self.gamma}\n"
        str_summary += f"    Ns x Na               | {self.num_states*self.num_actions}\n"
        if not (write_to is None) : 
            write_to.write(str_summary)
        return str_summary
        
    def is_action_idx_out_of_bounds(self, a_idx) :
        a_idx = np.array(a_idx)
        return np.logical_or(a_idx < 0, a_idx >= self.num_actions)
        
    def is_action_dx_out_of_bounds(self, a_dx) :
        if type(a_dx) in [list, set, np.ndarray] :
            return np.array([self.is_action_dx_out_of_bounds(a_dxp) for a_dxp in a_dx])
        return a_dx not in self.actions_dx
        
    def is_x_out_of_bounds(self, x) :
        if x[0] < self.x_min : return True
        if x[0] > self.x_max : return True
        if x[1] < self.x_min : return True
        if x[1] > self.x_max : return True
        return False
    
    def load_autoencoder_model(self, autoencoder_dir=None, embedding_length=None, create_decoder=True) :
        if autoencoder_dir  is None : autoencoder_dir  = self.autoencoder_dir
        if embedding_length is None : embedding_length = self.embedding_length
        if autoencoder_dir is None :
            raise RuntimeError(f"no autoencoder_dir provided and none already set")
        self.autoencoder_model = keras.models.load_model(autoencoder_dir)
        self.autoencoder_model.trainable = False
        if create_decoder :
            self.create_decoder_model_from_autoencoder()
    
    def load_encoder_model(self, encoder_dir=None, embedding_length=None) :
        if encoder_dir      is None : encoder_dir      = self.encoder_dir
        if embedding_length is None : embedding_length = self.embedding_length
        if encoder_dir is None :
            raise RuntimeError(f"no encoder_dir provided and none already set")
        self.encoder_model   = keras.models.load_model(encoder_dir)
        encoder_output_shape = self.encoder_model.layers[-1].output_shape
        self.encoder_model.trainable = False
        if not (embedding_length is None) and embedding_length > 0 :
            if encoder_output_shape[-1] != embedding_length :
                raise RuntimeError(f"expected encoder_model to have output shape of (None, {embedding_length}) using embedding_length={embedding_length} but shape {encoder_output_shape} found")
        else :
            self.embedding_length = encoder_output_shape[-1]
            
    def perform_action(self, weather_map, x, a_idx) :
        self.enforce_state_is_valid(x)
        self.enforce_action_idx_is_valid(a_idx)
        #  Iterate agent position, if hit boundary then add penalty and return to original position 
        a_dx = self.actions_dx[a_idx]
        x_p  = x + a_dx
        normed_reward_b = 0
        if self.is_x_out_of_bounds(x_p) :
            reward_b = self.r_per_b
            x_p      = x
        #    Get movement-based reward
        normed_dx = np.sqrt(a_dx[0]*a_dx[0] + a_dx[1]*a_dx[1]) / np.sqrt(2)
        #    Get weather-based reward
        normed_reward_w = weather_map[x[0], x[1], 0] / self.max_weather_intensity
        #  Return total reward and new agent state
        reward = self.r_per_dx*normed_dx + self.r_per_b*normed_reward_b + self.r_per_w*normed_reward_w
        return reward, x_p
    
    def print_summary(self, write_to=None) :
        if write_to is None : write_to = sys.stdout
        self.get_summary(write_to=write_to)
          

In [4]:
env = Environment(
                  square_size           = global_config["square_size"]          ,
                  max_weather_intensity = global_config["max_weather_intensity"],
                  autoencoder_dir       = global_config["autoencoder_dir"]      ,
                  encoder_dir           = global_config["encoder_dir"]          ,
                  embedding_length      = global_config["embedding_length"]     
                 )

print(env)

#data_X_np, data_Xe = env.generate_weather(1000, verbose=True)


Metal device set to: Apple M1 Pro
Environment config:
    autoencoder_dir       | models/2D_static_weather_20220808_emb20/autoencoder
    encoder_dir           | models/2D_static_weather_20220808_emb20/encoder
    square_size           | 40
    max_weather_intensity | 18
    embedding_length      | 20
    num_states            | 1600
    x_min                 | 0
    x_max                 | 39
    num_actions           | 9
    actions_dx            | [-1 -1] [-1  0] [-1  1] [ 0 -1] [0 0] [0 1] [ 1 -1] [1 0] [1 1]
    r_per_dx              | -0.2
    r_per_b               | -1.0
    r_per_w               | -1.0
    gamma                 | 0.95
    Ns x Na               | 14400



In [5]:

def plot_embedded_data(embedded_data_np, num_cols=4, show=True, close=True) :
    n_dim = embedded_data_np.shape[1]
    
    fig  = plt.figure(figsize=(3*n_dim, 3*n_dim))
    for idx_1 in range(n_dim) :
        for idx_2 in range(1+idx_1) :
            ax = fig.add_subplot(n_dim, n_dim, 1 + idx_1*n_dim + idx_2)
            ax.tick_params(axis="both", which="both", direction="in", right=True, top=True)
            if idx_1 == idx_2 :
                v = embedded_data_np[:,idx_1]
                vs = list(v)
                vs.sort()
                f16, f84 = vs[int(0.16*(len(v)-1))], vs[int(0.84*(len(v)-1))]
                ax.hist(v, bins=30)
                ax.text(0, 1.25, f"$f_{{68}}$=[{f16:.1f}, {f84:.1f}]", va="bottom", ha="left",
                        fontsize=20, transform=ax.transAxes)
                ax.text(0, 1.05, f"μ={np.mean(v):.1f}, σ={np.std(v):.1f}", va="bottom", ha="left",
                        fontsize=20, transform=ax.transAxes)
                continue
            ax.hist2d(embedded_data_np[:,idx_2], embedded_data_np[:,idx_1], bins=30)
    fig.subplots_adjust(wspace=0.2, hspace=0.2)
    
    if show :
        plt.show(fig)
    
    if close :
        plt.clf()
        plt.close(fig)
        
    return fig

#fig = plot_embedded_data(env.embedded_data_np, show=True, close=True)
#del fig


In [6]:
###
###  Utility functions
###

def generate_directory_for_file_path(fname, print_msg_on_dir_creation=True) :
    """
    Create the directory structure needed to place file fname. Call this before fig.savefig(fname, ...) to 
    make sure fname can be created without a FileNotFoundError
    Input:
       - fname: str
                name of file you want to create a tree of directories to enclose
                also create directory at this path if fname ends in '/'
       - print_msg_on_dir_creation: bool, default = True
                                    if True then print a message whenever a new directory is created
    """
    while "//" in fname :
        fname = fname.replace("//", "/")
    dir_tree = fname.split("/")
    dir_tree = ["/".join(dir_tree[:i]) for i in range(1,len(dir_tree))]
    dir_path = ""
    for dir_path in dir_tree :
        if len(dir_path) == 0 : continue
        if not os.path.exists(dir_path) :
            os.mkdir(dir_path)
            if print_msg_on_dir_creation :
                print(f"Directory {dir_path} created")
            continue
        if os.path.isdir(dir_path) : 
            continue
        raise RuntimeError(f"Cannot create directory {dir_path} because it already exists and is not a directory")
    

In [16]:

class Experiment() :
    
    def __init__(self, run_config_dict=None, **kwargs) :
        if run_config_dict is None :
            run_config_dict = {}
        self.configure(run_config_dict={}, **kwargs)
        
    def __str__(self) :
        return self.get_summary()
        
    def configure(self, run_config_dict, **kwargs) :
        self.bootstrap                     = self.resolve_argument("bootstrap"                    , True    , bool      , run_config_dict, **kwargs)
        self.max_epochs                    = self.resolve_argument("max_epochs"                   , -1      , np.int32  , run_config_dict, **kwargs)
        self.batch_size                    = self.resolve_argument("batch_size"                   , 1       , np.int32  , run_config_dict, **kwargs)
        self.optimizer_type                = self.resolve_argument("optimizer_type"               , "sgd"   , str       , run_config_dict, **kwargs)
        self.learning_rate                 = self.resolve_argument("learning_rate"                , 1e-3    , np.float32, run_config_dict, **kwargs)
        self.plot_estimate_after_epochs    = self.resolve_argument("plot_estimate_after_epochs"   , 5       , np.int32  , run_config_dict, **kwargs)
        self.plot_monitors_after_epochs    = self.resolve_argument("plot_monitors_after_epochs"   , 5       , np.int32  , run_config_dict, **kwargs)
        self.save_objects_after_epochs     = self.resolve_argument("save_objects_after_epochs"    , 1       , np.int32  , run_config_dict, **kwargs)
        self.clone_after_epochs            = self.resolve_argument("clone_after_epochs"           , 5       , np.int32  , run_config_dict, **kwargs)
        self.generate_weather_after_epochs = self.resolve_argument("generate_weather_after_epochs", 5       , np.int32  , run_config_dict, **kwargs)
        self.shuffle_states_after_epochs   = self.resolve_argument("shuffle_states_after_epochs"  , 1       , np.int32  , run_config_dict, **kwargs)
        self.num_step_returns              = self.resolve_argument("num_step_returns"             , 1       , np.int32  , run_config_dict, **kwargs)
        self.run_tag                       = self.resolve_argument("run_tag"                      , "no_tag", str       , run_config_dict, **kwargs)
        self.run_idx                       = self.resolve_argument("run_idx"                      , 0       , np.int32  , run_config_dict, **kwargs)
        self.set_derived_constants()
        
    def export_run_config_dict(self) :
        run_config_dict = {}
        run_config_dict["bootstrap"                    ] = self.bootstrap
        run_config_dict["max_epochs"                   ] = self.max_epochs
        run_config_dict["batch_size"                   ] = self.batch_size
        run_config_dict["optimizer_type"               ] = self.optimizer_type
        run_config_dict["learning_rate"                ] = self.learning_rate
        run_config_dict["plot_estimate_after_epochs"   ] = self.plot_estimate_after_epochs
        run_config_dict["plot_monitors_after_epochs"   ] = self.plot_monitors_after_epochs
        run_config_dict["save_objects_after_epochs"    ] = self.save_objects_after_epochs
        run_config_dict["clone_after_epochs"           ] = self.clone_after_epochs
        run_config_dict["generate_weather_after_epochs"] = self.generate_weather_after_epochs
        run_config_dict["shuffle_states_after_epochs"  ] = self.shuffle_states_after_epochs
        run_config_dict["num_step_returns"             ] = self.num_step_returns
        run_config_dict["run_tag"                      ] = self.run_tag
        run_config_dict["run_idx"                      ] = self.run_idx
        return run_config_dict
        
    def create_config(self, env, verbose=True) :
        #  Create config message
        config_str = ""
        config_str += f"="*114 + "\n"
        config_str += env.get_summary()
        config_str += f"="*114 + "\n"
        config_str += self.get_summary()
        config_str += f"="*114 + "\n"
        #  Print to stdout
        if verbose :
            print(config_str)
        #  Print to file
        config_fname = f"{self.top_directory}/config.txt"
        generate_directory_for_file_path(config_fname, print_msg_on_dir_creation=verbose)
        with open(config_fname, "w") as config_file :
            config_file.write(config_str)
            if hasattr(self, "q_model") :
                config_file.write("\nq-value model config:\n")
                self.q_model.summary(print_fn=lambda x: config_file.write(x + '\n'))
    
    def get_bs_values_for_embedding(self, embedding, verbose=True) :
        if verbose : print("Evaluating bootstrap values")
        embedding  = list(embedding.numpy())
        embeddings = [embedding for idx in range(self.num_states)]
        embeddings = tf.constant(embeddings)
        bs_values = self.bs_model([embeddings, self.ordered_states], training=False)
        bs_values = bs_values.numpy().reshape((self.square_size, self.square_size, self.num_actions))
        if verbose : print("Bootstrap values evaluated")
        return bs_values
    
    def get_q_values_for_embedding(self, embedding, verbose=True) :
        if verbose : print("Evaluating q values")
        embedding  = list(embedding.numpy())
        embeddings = [embedding for idx in range(self.num_states)]
        embeddings = tf.constant(embeddings)
        q_values = self.q_model([embeddings, self.ordered_states], training=False)
        q_values = q_values.numpy().reshape((self.square_size, self.square_size, self.num_actions))
        if verbose : print("q values evaluated")
        return q_values
    
    def get_summary(self, write_to=None) :
        str_summary  =  "Experiment config:\n"
        str_summary += f"    bootstrap                     | {self.bootstrap}\n"
        str_summary += f"    max_epochs                    | {self.max_epochs}\n"
        str_summary += f"    batch_size                    | {self.batch_size}\n"
        str_summary += f"    optimizer_type                | {self.optimizer_type:.7}\n"
        str_summary += f"    learning_rate                 | {self.learning_rate:.7}\n"
        str_summary += f"    plot_estimate_after_epochs    | {self.plot_estimate_after_epochs}\n"
        str_summary += f"    plot_monitors_after_epochs    | {self.plot_monitors_after_epochs}\n"
        str_summary += f"    save_objects_after_epochs     | {self.save_objects_after_epochs}\n"
        str_summary += f"    clone_after_epochs            | {self.clone_after_epochs}\n"
        str_summary += f"    generate_weather_after_epochs | {self.generate_weather_after_epochs}\n"
        str_summary += f"    shuffle_states_after_epochs   | {self.shuffle_states_after_epochs}\n"
        str_summary += f"    num_step_returns              | {self.num_step_returns}\n"
        str_summary += f"    run_tag                       | {self.run_tag}\n"
        str_summary += f"    run_idx                       | {self.run_idx}\n"
        if not (write_to is None) : 
            write_to.write(str_summary)
        return str_summary
    
    def plot_training_curves(self, verbose=True, save=False, show=False, close=False) :
        if verbose :
            print("Plotting training curves")

        fig = plt.figure(figsize=(25,20))
        fig.set_facecolor("white")
        fig.set_alpha(1)

        ax1 = fig.add_subplot(3, 1, 1)
        ax1.grid(True, which='both')
        ax1.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
        ax1.set_title(r"Mean loss per batch", fontsize=30)
        ax1.xaxis.set_ticklabels([])
        ax1.plot(self.epochs_record, self.loss_record, "o-", c="r", lw=2, ms=4)
        ax1.set_yscale("log")

        ax2 = fig.add_subplot(3, 1, 2)
        ax2.grid(True, which='both')
        ax2.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
        ax2.set_title(r"Min $q(s,a)$ over all batches", fontsize=30)
        ax2.xaxis.set_ticklabels([])
        ax2.plot(self.epochs_record, self.min_q_record, "o-", c="r", lw=2, ms=4)
        ax2.axhline(0, ls="--", lw=2, c="gray")

        ax3 = fig.add_subplot(3, 1, 3)
        ax3.grid(True, which='both')
        ax3.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=30)
        ax3.set_title(r"Max $q(s,a)$ over all batches", fontsize=30)
        ax3.set_xlabel(r"Epoch", labelpad=15, fontsize=30)
        ax3.plot(self.epochs_record, self.max_q_record, "o-", c="r", lw=2, ms=4)
        ax3.axhline(0, ls="--", lw=2, c="gray")

        fig.subplots_adjust(hspace=0.2)

        if save :
            fname = f"{self.top_directory}/training_curve.pdf"
            generate_directory_for_file_path(fname, print_msg_on_dir_creation=verbose)
            plt.savefig(fname, bbox_inches="tight")
            if verbose :
                print(f"Training curves plot saved to file {fname}")
                
        if show :
            plt.show(fig)
            
        if close :
            plt.clf()
            plt.close(fig)
            del fig, ax1, ax2, ax3
            fig, ax1, ax2, ax3 = None, None, None, None
            plt.close("all")
            if verbose :
                print(f"Training curves plot closed")

        return fig, ax1, ax2, ax3
    
    def plot_greedy_policy_for_weather(self, env, verbose=True, save=False, show=False, close=False) :
     
        #  Keep track of how long plotting takes, to help inform how often to call this function    
        start_time = time.time()

        #  Calculate q-values and decoded weather
        weather_maps_np, embeddings = env.generate_weather(verbose=verbose)
        decoded_weather_maps        = env.decoder_model(embeddings, training=False)
        decoded_weather_map_np      = decoded_weather_maps[0].numpy()
        decoded_weather_map_np     /= env.max_weather_intensity
        weather_map_np, embedding   = weather_maps_np[0], embeddings[0]
        weather_map_np             /= env.max_weather_intensity
        q_values                    = self.get_q_values_for_embedding(embedding, verbose=verbose)

        #  Set up plot
        fig = plt.figure(figsize=(9*self.square_size/11.5,9*self.square_size/11.5))
        
        ax1 = fig.add_subplot(2, 2, 1)
        ax1.set_xlim(-0.5, self.square_size-0.5)
        ax1.set_ylim(-0.5, self.square_size-0.5)
        ax1.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=16)
        
        ax2 = fig.add_subplot(2, 2, 2)
        ax2.set_xlim(-0.5, self.square_size-0.5)
        ax2.set_ylim(-0.5, self.square_size-0.5)
        ax2.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=16)
        
        ax3 = fig.add_subplot(2, 2, 3)
        ax3.set_xlim(-0.5, self.square_size-0.5)
        ax3.set_ylim(-0.5, self.square_size-0.5)
        ax3.tick_params(axis="both", which="both", right=True, top=True, direction="in", labelsize=16)
        
        fig.subplots_adjust(hspace=1./9., wspace=1./9.)

        #  Draw original and decoded weather
        ax1.imshow(weather_map_np        [:,:,0].transpose(), origin="lower", alpha=0.5, cmap="Greys", vmin=0, vmax=1)
        ax2.imshow(decoded_weather_map_np[:,:,0].transpose(), origin="lower", alpha=0.5, cmap="Greys", vmin=0, vmax=1)

        #  Draw arrows by looping over states and finding greedy action according to q-models
        max_values = np.zeros(shape=(self.square_size, self.square_size))
        for x in range(self.square_size) :
            for y in range(self.square_size) :
                if verbose :
                    sys.stdout.write(f"\rEvaluating greedy policy for agent state ({x}, {y})".ljust(100))
                a_idx  = np.argmax(q_values[x,y])
                dx, dy = env.get_action_dx_from_index(a_idx)
                max_values[x, y] = np.max(q_values[x,y])
                if dx == 0 and dy == 0 :
                    ax1.plot(x, y, "o", markersize=8, c="b", alpha=1)
                    ax2.plot(x, y, "o", markersize=8, c="b", alpha=1)
                else :
                    ax1.arrow(x - 0.3*dx, y - 0.3*dy, 0.6*dx, 0.6*dy, head_width=0.25, length_includes_head=True, color="b")
                    ax2.arrow(x - 0.3*dx, y - 0.3*dy, 0.6*dx, 0.6*dy, head_width=0.25, length_includes_head=True, color="b")
        
        
        #  Draw max action-values
        max_values -= max_values.min()
        max_values /= max_values.max()
        ax3.imshow(1 - max_values.transpose(), origin="lower", alpha=0.5, cmap="Greys", vmin=0, vmax=1)

        #  Draw text boxes displaying title and num. epochs
        ax1.text(0, 1.01, "Weather intensity and greedy policy", transform=ax1.transAxes, 
                 fontsize=17, weight="bold", ha="left", va="bottom")
        ax2.text(0, 1.01, "Decoded weather intensity and greedy policy", transform=ax2.transAxes, 
                 fontsize=17, weight="bold", ha="left", va="bottom")
        ax2.text(1, 1.01, f"After {self.epoch_idx} epochs", ha="right", va="bottom", weight="bold", 
                 transform=ax2.transAxes, fontsize=17)
        ax3.text(0, 1.01, "max$_a$ $q(s,a)$", transform=ax3.transAxes, fontsize=17, weight="bold", ha="left", va="bottom")
        
        #  Verbose messaging
        if verbose :
            sys.stdout.write(f"\nPlot created in {time.time()-start_time:.2f}s".ljust(100)+"\n")

        #  Save
        if save :
            fname = f"{self.top_directory}/greedy_policy_epoch{self.epoch_idx}.pdf"
            generate_directory_for_file_path(fname, print_msg_on_dir_creation=verbose)
            plt.savefig(fname, bbox_inches="tight")
            if verbose :
                print(f"Greedy policy plot saved to file {fname}")
               
        #  Show 
        if show :
            plt.show(fig)
            
        #  Close
        if close :
            plt.clf()
            plt.close(fig)
            del fig, ax1, ax2
            fig, ax1, ax2 = None, None, None
            plt.close("all")
            if verbose :
                print(f"Greedy policy plot closed")

        #  Return figure and axis
        return fig, ax1, ax2
    
    def print_summary(self, write_to=None) :
        if write_to is None : write_to = sys.stdout
        self.get_summary(write_to=write_to)
        
    def record_monitors(self, epoch_idx, loss, min_q, max_q, mse_true) :
        self.epochs_record.append(epoch_idx)
        self.loss_record  .append(loss)
        self.min_q_record .append(min_q)
        self.max_q_record .append(max_q)
        
    def resolve_argument(self, arg_name, arg_default, dtype, config_dict, **kwargs) :
        if arg_name in kwargs :
            arg_val = kwargs.get(arg_name)
        else :
            arg_val = config_dict.get(arg_name, arg_default)
        return dtype(arg_val)
    
    def run(self, env, initialise=True, debug=False, info=True) :
        
        #   Re-initialise training objects to start training from scratch - must be done on first run() call
        if initialise :
            self._initialise_run(env, debug=debug, info=info)
        
        #   At start of new training loop, set constants and plot value functions
        start_time = time.time()
        self.plot_greedy_policy_for_weather(env, verbose=debug, save=True, show=False, close=False)
        
        #   Start epochs loop
        while self.epoch_idx < self.max_epochs or self.max_epochs < 0 :

            #   Start epoch printout message
            if debug  : print(f"Epoch {self.epoch_idx+1} / {self.max_epochs}  [t={time.time()-start_time:.2f}s")
            elif info : sys.stdout.write(f"Epoch {self.epoch_idx+1} / {self.max_epochs}  [t={time.time()-start_time:.2f}s]")

            #   Set bootstrap model
            if self.bootstrap and self.clone_after_epochs > 0 and self.epoch_idx % self.clone_after_epochs == 0 and self.epoch_idx > 0 :
                if debug : print("Copying q_model weights to bs_model")
                self.bs_model.set_weights(self.q_model.get_weights()) 
                
            #   Generate new weather
            if self.generate_weather_after_epochs > 0 and self.epoch_idx % self.generate_weather_after_epochs == 0 and self.epoch_idx > 0 :
                if debug : print("Generating new weather")
                self._initialise_weather(env, verbose=debug)
                
            #   Shuffle states to ensure each state does not always see the same weather
            if self.shuffle_states_after_epochs > 0 and self.epoch_idx % self.shuffle_states_after_epochs == 0 and self.epoch_idx > 0 :
                self._shuffle_states(verbose=debug)
                
            #   Start batches loop
            epoch_losses, min_q_values, max_q_values = [], [], []
            for batch_idx in range(math.ceil(self.num_states/self.batch_size)) :
                
                #   Resolve sample indices to be used for this batch update
                batch_idx_low     = batch_idx*self.batch_size
                batch_idx_high    = min((batch_idx+1)*self.batch_size, self.num_states)
                actual_batch_size = batch_idx_high - batch_idx_low
                if actual_batch_size == 0 : continue

                #   Update the current epoch message to keep track of batch progress 
                if debug  : print(f"Epoch {self.epoch_idx+1} / {self.max_epochs} batch indices ({batch_idx_low}, {batch_idx_high}) / {self.num_states}  [t={time.time()-start_time:.2f}s]")
                elif info : sys.stdout.write(f"\rEpoch {self.epoch_idx+1} / {self.max_epochs} batch indices ({batch_idx_low}, {batch_idx_high}) / {self.num_states}  [t={time.time()-start_time:.2f}s]")

                #   Get states to update for this batch
                batch_train_weather_np = self.train_weather_np[batch_idx_low:batch_idx_high]
                batch_train_embeddings = self.train_embeddings[batch_idx_low:batch_idx_high]
                batch_train_x          = self.shuffled_states [batch_idx_low:batch_idx_high]
                
                #   Evaluate q_target for this batch (and update global monitor value)
                batch_q_target = []
                for weather_np, embedding, x in zip(batch_train_weather_np, batch_train_embeddings, batch_train_x) :
                    if self.bootstrap :
                        bs_values = self.get_bs_values_for_embedding(embedding, verbose=False)
                    q_values  = self.get_q_values_for_embedding (embedding, verbose=False)
                    q_a       = []
                    for a_idx in range(self.num_actions) :
                        g, x_p, a_idx_p, step_y = 0, x, a_idx, 1.
                        for step_idx in range(self.num_step_returns) :
                            r, x_p  = env.perform_action(weather_np, x_p, a_idx_p)
                            g      += step_y * r
                            step_y *= env.gamma
                            a_idx_p = np.argmax(q_values[x_p[0], x_p[1]])
                        if self.bootstrap :
                            g += step_y * bs_values[x_p[0], x_p[1], a_idx_p]
                        q_a.append(g)
                    batch_q_target.append(q_a)
                if debug : print(f"batch_q_target between {np.min(batch_q_target):.5} and {np.max(batch_q_target):.5}")
                batch_q_target = tf.constant(batch_q_target)

                #   Apply gradient updates and store monitor values
                with tf.GradientTape() as tape:
                    batch_q_model = self.q_model([batch_train_embeddings, batch_train_x], training=True)
                    batch_loss    = self.loss_fcn(batch_q_target, batch_q_model)
                    if not np.isfinite(batch_loss.numpy()) :
                        raise RuntimeError(f"Batch loss is NaN, training aborted without gradient update")
                    grads = tape.gradient(batch_loss, self.q_model.trainable_weights)
                    self.optimizer.apply_gradients(zip(grads, self.q_model.trainable_weights))
                    
                #  Update monitors
                epoch_losses.append(batch_loss.numpy())
                min_q_values.append(batch_q_model.numpy().min())
                max_q_values.append(batch_q_model.numpy().max())
                if debug :
                    print(f"Batch loss: {epoch_losses[-1]:.5}")
                    print(f"Batch min-q: {min_q_values[-1]:.5}")
                    print(f"Batch max-q: {max_q_values[-1]:.5}")
                
                    
            #   Calculate post-epoch MSE wrt true values, and other monitors
            epoch_mean_loss, epoch_min_q, epoch_max_q = np.mean(epoch_losses), np.min(min_q_values), np.max(max_q_values)
            
            #   Store monitor values
            self.epochs_record.append(self.epoch_idx)
            self.loss_record  .append(epoch_mean_loss)
            self.min_q_record .append(epoch_min_q)
            self.max_q_record .append(epoch_max_q)
            
            #   End print line
            sys.stdout.write(f"\rEpoch {self.epoch_idx+1} / {self.max_epochs}  [t={time.time()-start_time:.2f}s]  <loss = {epoch_mean_loss:.5f}, min_q = {epoch_min_q:.1f}, max_q = {epoch_max_q:.1f}>".ljust(100)+"\n")
    
            #   Update number of completed epochs
            self.epoch_idx += 1
        
            #   Plot value function
            if self.plot_estimate_after_epochs > 0 and self.epoch_idx % self.plot_estimate_after_epochs == 0 :
                self.plot_greedy_policy_for_weather(env, verbose=debug, save=True, show=False, close=False)

            #   Plot training curves
            if self.plot_monitors_after_epochs > 0 and self.epoch_idx % self.plot_monitors_after_epochs == 0 :
                self.plot_training_curves(verbose=debug, save=True, close=True)

            #   Save objects
            if self.save_objects_after_epochs > 0 and self.epoch_idx % self.save_objects_after_epochs == 0 :
                self.save_run_progress(verbose=debug)
        
        #   Make sure final plots and objects are saved
        self.plot_greedy_policy_for_weather(env, verbose=debug, save=True, show=False, close=False)
        self.plot_value_functions(bs_values, q_values, save=True, close=True)
        self.plot_training_curves(save=True, close=True)
        self.save_run_progress()
    
    def save_run_progress(self, verbose=True) :
        fname   = f"{self.top_directory}/saved_objects.pickle"
        to_save = self.export_run_config_dict()
        to_save["run:epoch_idx"]        = self.epoch_idx
        to_save["run:epochs_record"]    = self.epochs_record
        to_save["run:loss_record"]      = self.loss_record
        to_save["run:min_q_record"]     = self.min_q_record
        to_save["run:max_q_record"]     = self.max_q_record
        to_save["run:embedding_length"] = self.embedding_length
        to_save["run:square_size"]      = self.square_size
        to_save["run:num_states"]       = self.num_states
        to_save["run:num_actions"]      = self.num_actions
        to_save["general:bootstrap"                    ] = self.bootstrap
        to_save["general:max_epochs"                   ] = self.max_epochs
        to_save["general:batch_size"                   ] = self.batch_size
        to_save["general:optimizer_type"               ] = self.optimizer_type
        to_save["general:learning_rate"                ] = self.learning_rate
        to_save["general:plot_estimate_after_epochs"   ] = self.plot_estimate_after_epochs
        to_save["general:plot_monitors_after_epochs"   ] = self.plot_monitors_after_epochs
        to_save["general:save_objects_after_epochs"    ] = self.save_objects_after_epochs
        to_save["general:clone_after_epochs"           ] = self.clone_after_epochs
        to_save["general:generate_weather_after_epochs"] = self.generate_weather_after_epochs
        to_save["general:shuffle_states_after_epochs"  ] = self.shuffle_states_after_epochs
        to_save["general:num_step_returns"             ] = self.num_step_returns
        to_save["general:run_tag"                      ] = self.run_tag
        to_save["general:run_idx"                      ] = self.run_idx
        generate_directory_for_file_path(fname, print_msg_on_dir_creation=verbose)
        pickle.dump(to_save, open(fname,"wb"))
        tf_log_level = tf.get_logger().level
        tf.get_logger().setLevel('WARNING')
        self.q_model .save(f"{self.top_directory}/q_model")
        if self.bootstrap :
            self.bs_model.save(f"{self.top_directory}/bs_model")
        tf.get_logger().setLevel(tf_log_level)
        if verbose :
            print(f"Run objects saved to file {fname}")
            print(f"q model saved to file {self.top_directory}/q_model")
            if self.bootstrap :
                print(f"Bootstrap model saved to file {self.top_directory}/bs_model")
        
    def set_derived_constants(self) :
        self.top_directory = f"figures/Helicopter_2D_from_embedding/{self.run_tag}/experiment_{self.run_idx}"
                
    def _create_q_model(self, env, name=None) :
        means, scales  = env.get_embedding_means_scales()
        input_layer_w  = Input((env.embedding_length,))
        next_layer_w   = Normalization(mean=means, variance=scales**2)(input_layer_w)
        next_layer_w   = Dense(100, activation="tanh")(next_layer_w)
        input_layer_x  = Input((2,))
        next_layer_x   = Rescaling(2./env.x_range, offset=-(env.x_max+env.x_min)/env.x_range)(input_layer_x)
        next_layer_x   = Dense(100, activation="tanh")(next_layer_x)
        next_layer     = Concatenate()([next_layer_w, next_layer_x])
        next_layer     = Dense(500, activation="relu")(next_layer)
        next_layer     = Dense(500, activation="relu")(next_layer)
        next_layer     = Dense(100, activation="relu")(next_layer)
        output_layer   = Dense(env.num_actions, activation="linear")(next_layer)
        model          = Model([input_layer_w, input_layer_x], output_layer, name=name)
        model.compile(loss="mse", optimizer="sgd")
        return model
        
    def _initialise_weather(self, env, verbose=True) :
        if verbose :
            print(f"Initialising weather")
        train_weather_np, train_embeddings = env.generate_weather(self.num_states, verbose=verbose)
        self.train_weather_np = train_weather_np
        self.train_embeddings = train_embeddings
        return self.train_weather_np, self.train_embeddings
    
    def _initialise_keras_objects(self, env, verbose=True) :
        if verbose :
            print(f"Initialising keras objects")
        self.loss_fcn = MeanSquaredError()
        self.q_model  = self._create_q_model(env, "q_model")
        self.bs_model = None
        if self.bootstrap :
            self.bs_model = self._create_q_model(env, "bs_model")
            self.bs_model.set_weights(self.q_model.get_weights())
        if self.optimizer_type.lower() == "sgd" :
            self.optimizer = SGD(learning_rate=self.learning_rate)
        elif self.optimizer_type.lower() == "adam" :
            self.optimizer = Adam(learning_rate=self.learning_rate)
        else :
            raise NotImplementedError(f"optimizer_type = {self.optimizer_type} not recognised by method initialise_keras_objects")
        return self.loss_fcn, self.q_model, self.bs_model, self.optimizer

    def _initialise_monitor_records(self, verbose=True) :
        if verbose :
            print(f"Initialising monitor records")
        self.epochs_record = []
        self.loss_record   = []
        self.min_q_record  = []
        self.max_q_record  = []
        
    def _initialise_run(self, env, debug=False, info=True) :
        self.epoch_idx        = 0
        self.embedding_length = env.embedding_length
        self.square_size      = env.square_size
        self.num_states       = env.num_states
        self.num_actions      = env.num_actions
        self._initialise_monitor_records(verbose=info)
        self._initialise_weather(env, verbose=info)
        self._initialise_keras_objects(env, verbose=info)
        self._shuffle_states(verbose=info)
        self.create_config(env, verbose=info)
        
    def _shuffle_states(self, verbose=True) :
        x_pairs = []
        for x1 in range(self.square_size) :
            for x2 in range(self.square_size) :
                x_pairs.append((x1, x2))
        self.ordered_states  = tf.constant(np.array(x_pairs))
        np.random.shuffle(x_pairs)
        self.shuffled_states = tf.constant(np.array(x_pairs))
        if verbose :
            print(f"{len(x_pairs)} (x1,x2) states created and shuffled")
        return self.shuffled_states
        

In [17]:
exp = Experiment(batch_size=20,
                 plot_estimate_after_epochs=1,
                 plot_monitors_after_epochs=1,
                 clone_after_epochs=1,
                 generate_weather_after_epochs=10,
                 shuffle_states_after_epochs=1,
                 num_step_returns=1,
                 optimizer_type="adam",
                 learning_rate=3e-3,
                 run_idx=0,
                 bootstrap=False,
                 run_tag="no_bootstrap")

# Metal seems to hang sometimes when using GPU, so leave on CPU overnight
# Seems to run just as fast anyway (surprising)

with tf.device('CPU:0') :
    exp.run(env, debug=False)

Initialising monitor records
Initialising weather
Generating weather 1600 / 1600                                                                     
Embedding weather
Initialising keras objects
1600 (x1,x2) states created and shuffled
Environment config:
    autoencoder_dir       | models/2D_static_weather_20220808_emb20/autoencoder
    encoder_dir           | models/2D_static_weather_20220808_emb20/encoder
    square_size           | 40
    max_weather_intensity | 18
    embedding_length      | 20
    num_states            | 1600
    x_min                 | 0
    x_max                 | 39
    num_actions           | 9
    actions_dx            | [-1 -1] [-1  0] [-1  1] [ 0 -1] [0 0] [0 1] [ 1 -1] [1 0] [1 1]
    r_per_dx              | -0.2
    r_per_b               | -1.0
    r_per_w               | -1.0
    gamma                 | 0.95
    Ns x Na               | 14400
Experiment config:
    bootstrap                     | False
    max_epochs                    | -1
    batch_siz

  max_values /= max_values.max()


Epoch 55 / -1  [t=3122.45s]  <loss = 0.02367, min_q = -1.4, max_q = -0.1>                          
Epoch 56 / -1  [t=3154.19s]  <loss = 0.02092, min_q = -1.3, max_q = 0.0>                           
Epoch 57 / -1  [t=3186.44s]  <loss = 0.02248, min_q = -1.2, max_q = -0.1>                          
Epoch 58 / -1  [t=3218.27s]  <loss = 0.02168, min_q = -1.2, max_q = -0.1>                          
Epoch 59 / -1  [t=3250.57s]  <loss = 0.02250, min_q = -1.2, max_q = -0.1>                          
Epoch 60 / -1  [t=3282.42s]  <loss = 0.02153, min_q = -1.3, max_q = -0.1>                          
Epoch 61 / -1  [t=3332.40s]  <loss = 0.02012, min_q = -1.5, max_q = 0.0>                           
Epoch 62 / -1  [t=3364.36s]  <loss = 0.02336, min_q = -1.3, max_q = -0.1>                          
Epoch 63 / -1  [t=3396.51s]  <loss = 0.02430, min_q = -1.3, max_q = -0.1>                          
Epoch 64 / -1  [t=3428.87s]  <loss = 0.02229, min_q = -1.4, max_q = -0.0>                          


Epoch 137 / -1  [t=5683.54s]  <loss = 0.01948, min_q = -1.2, max_q = -0.1>                         
Epoch 138 / -1  [t=5711.96s]  <loss = 0.02004, min_q = -1.4, max_q = -0.1>                         
Epoch 139 / -1  [t=5740.91s]  <loss = 0.02135, min_q = -1.4, max_q = -0.1>                         
Epoch 140 / -1  [t=5769.66s]  <loss = 0.02044, min_q = -1.4, max_q = -0.1>                         
Epoch 141 / -1  [t=5815.45s]  <loss = 0.01860, min_q = -1.2, max_q = 0.0>                          
Epoch 142 / -1  [t=5844.84s]  <loss = 0.01872, min_q = -1.3, max_q = -0.1>                         
Epoch 143 / -1  [t=5874.58s]  <loss = 0.01997, min_q = -1.3, max_q = -0.1>                         
Epoch 144 / -1  [t=5903.82s]  <loss = 0.01906, min_q = -1.4, max_q = -0.1>                         
Epoch 145 / -1  [t=5933.01s]  <loss = 0.01876, min_q = -1.4, max_q = -0.1>                         
Epoch 146 / -1  [t=5962.36s]  <loss = 0.01808, min_q = -1.3, max_q = -0.0>                         


  max_values /= max_values.max()


Epoch 195 / -1  [t=7488.79s]  <loss = 0.01810, min_q = -1.2, max_q = -0.1>                         
Epoch 196 / -1  [t=7518.14s]  <loss = 0.01894, min_q = -1.6, max_q = -0.1>                         
Epoch 197 / -1  [t=7547.12s]  <loss = 0.01667, min_q = -1.2, max_q = -0.1>                         
Epoch 198 / -1  [t=7576.48s]  <loss = 0.01767, min_q = -1.3, max_q = -0.1>                         
Epoch 199 / -1  [t=7605.60s]  <loss = 0.01885, min_q = -1.4, max_q = 0.0>                          
Epoch 200 / -1  [t=7634.88s]  <loss = 0.01873, min_q = -1.3, max_q = -0.1>                         
Epoch 201 / -1  [t=7681.99s]  <loss = 0.01829, min_q = -1.3, max_q = 1.1>                          
Epoch 202 / -1  [t=7712.11s]  <loss = 0.01863, min_q = -1.3, max_q = -0.1>                         
Epoch 203 / -1  [t=7741.89s]  <loss = 0.01859, min_q = -1.4, max_q = -0.1>                         
Epoch 204 / -1  [t=7771.79s]  <loss = 0.01971, min_q = -1.2, max_q = 0.5>                          


Epoch 277 / -1  [t=10026.80s]  <loss = 0.01787, min_q = -1.2, max_q = -0.1>                        
Epoch 278 / -1  [t=10055.57s]  <loss = 0.01808, min_q = -1.3, max_q = 0.3>                         
Epoch 279 / -1  [t=10084.58s]  <loss = 0.01798, min_q = -1.2, max_q = -0.1>                        
Epoch 280 / -1  [t=10113.32s]  <loss = 0.01725, min_q = -1.3, max_q = -0.1>                        
Epoch 281 / -1  [t=10159.28s]  <loss = 0.01796, min_q = -1.3, max_q = -0.1>                        
Epoch 282 / -1  [t=10188.17s]  <loss = 0.01721, min_q = -1.3, max_q = -0.1>                        
Epoch 283 / -1  [t=10217.22s]  <loss = 0.01803, min_q = -1.3, max_q = -0.1>                        
Epoch 284 / -1  [t=10245.95s]  <loss = 0.01836, min_q = -1.3, max_q = -0.1>                        
Epoch 285 / -1  [t=10275.01s]  <loss = 0.01867, min_q = -1.3, max_q = -0.1>                        
Epoch 286 / -1  [t=10303.78s]  <loss = 0.01718, min_q = -1.3, max_q = -0.1>                        


Epoch 359 / -1  [t=12531.67s]  <loss = 0.01758, min_q = -1.3, max_q = -0.1>                        
Epoch 360 / -1  [t=12560.52s]  <loss = 0.01742, min_q = -1.3, max_q = -0.1>                        
Epoch 361 / -1  [t=12606.82s]  <loss = 0.01735, min_q = -1.4, max_q = -0.1>                        
Epoch 362 / -1  [t=12636.27s]  <loss = 0.01714, min_q = -1.2, max_q = -0.1>                        
Epoch 363 / -1  [t=12665.60s]  <loss = 0.01876, min_q = -1.4, max_q = -0.1>                        
Epoch 364 / -1  [t=12695.16s]  <loss = 0.01817, min_q = -1.2, max_q = -0.1>                        
Epoch 365 / -1  [t=12724.60s]  <loss = 0.01787, min_q = -1.2, max_q = -0.1>                        
Epoch 366 / -1  [t=12754.16s]  <loss = 0.01810, min_q = -1.6, max_q = -0.1>                        
Epoch 367 / -1  [t=12783.77s]  <loss = 0.01735, min_q = -1.2, max_q = 0.3>                         
Epoch 368 / -1  [t=12813.18s]  <loss = 0.01755, min_q = -1.4, max_q = -0.1>                        


KeyboardInterrupt: 