In [3]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import tensorflow as tf
import angorapy as apy

env = apy.make_env("LunarLanderContinuous-v2")
distribution = apy.policies.BetaPolicyDistribution(env)

In [4]:
from tensorflow.keras.layers import TimeDistributed
from angorapy.utilities.util import env_extract_dims


def build_my_amazing_model(env, distribution, bs=1, sequence_length=1):
    state_dimensionality, n_actions = env_extract_dims(env)

    inputs = tf.keras.Input(batch_shape=(bs, sequence_length,) + state_dimensionality["proprioception"], name="proprioception")
    masked = tf.keras.layers.Masking(batch_input_shape=(bs, sequence_length,) + (inputs.shape[-1], ))(inputs)

    x = TimeDistributed(tf.keras.layers.Dense(32))(masked)
    x = TimeDistributed(tf.keras.layers.Dense(32))(x)

    x, *_ = tf.keras.layers.LSTM(64,
                       stateful=True,
                       return_sequences=True,
                       return_state=True,
                       batch_size=bs,
                       name="policy_recurrent_layer")(x)

    x_policy = tf.keras.layers.Dense(32)(x)
    x_policy = tf.keras.layers.Dense(32)(x_policy)

    x_value = tf.keras.layers.Dense(32)(x)
    x_value = tf.keras.layers.Dense(32)(x_value)

    out_policy = distribution.build_action_head(n_actions, x_policy.shape[1:], bs)(x_policy)
    out_value = tf.keras.layers.Dense(1)(x_value)

    policy = tf.keras.Model(inputs=inputs, outputs=out_policy, name="my_policy_function")
    value = tf.keras.Model(inputs=inputs, outputs=out_value, name="my_value_function")
    joint = tf.keras.Model(inputs=inputs, outputs=[out_policy, out_value], name="my_joint_networks")

    return policy, value, joint


In [5]:
agent = apy.Agent(build_my_amazing_model, env, horizon=1024, workers=1, distribution=distribution)
agent.drill(n=10, epochs=3, batch_size=32)
agent.save_agent_state()

2023-03-25 12:08:42.262399: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2023-03-25 12:08:42.264334: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2023-03-25 12:08:42.297482: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:15:00.0 name: NVIDIA GeForce RTX 2080 SUPER computeCapability: 7.5
coreClock: 1.845GHz coreCount: 48 deviceMemorySize: 7.78GiB deviceMemoryBandwidth: 462.00GiB/s
2023-03-25 12:08:42.297680: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-03-25 12:08:42.304218: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2023-03-25 12:08:42.304354: I tensorflow/stream_execut

Detected 0 GPU devices.
Using [StateNormalizationTransformer, RewardNormalizationTransformer] for preprocessing.
An MPI Optimizer with 1 ranks has been created; the following ranks optimize: [0]


Drill started using 1 processes for 1 workers of which 1 are optimizers. Worker distribution: [1].
IDs over Workers: [[0]]
IDs over Optimizers: [[0]]

The policy is recurrent and the batch size is interpreted as the number of transitions per policy update. Given the batch size of 32 this results in: 
	2 chunks per update and 32 updates per epoch
	Batch tilings of (1, 2) per process and (1, 2) in total.




2023-03-25 12:08:43.576110: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


Gathering cycle 0...

2023-03-25 12:08:44.558125: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2023-03-25 12:08:44.575742: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 3600000000 Hz


[92mBefore Training[0m: r: [91m -174.88[0m; len: [94m  114.88[0m; n: [94m  8[0m; loss: [[94m  pi  [0m|[94m  v     [0m|[94m  ent [0m]; eps: [94m    0[0m; lr: [94m1.00e-03[0m; upd: [94m     0[0m; f: [94m   0.000[0mk; y.exp: [94m0.000[0m; w: [94m0.04[0m; times:  ; took s [unknown time left]; mem: 1.21/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:06<03:18, 14.97it/s] 


[92mCycle     1/10[0m: r: [91m -208.88[0m; len: [94m  108.00[0m; n: [94m  8[0m; loss: [[94m -1.02[0m|[94m    0.20[0m|[94m -0.14[0m]; eps: [94m    8[0m; lr: [94m1.00e-03[0m; upd: [94m    96[0m; f: [94m   1.024[0mk; w: [94m0.06[0m; times: [9.3|0.0|6.7] [58|0|42]; took 15.93s [2.4mins left]; mem: 1.24/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:05<02:59, 16.58it/s]


[92mCycle     2/10[0m: r: [91m -124.06[0m; len: [94m   99.56[0m; n: [94m  9[0m; loss: [[94m -0.92[0m|[94m    0.08[0m|[94m -0.17[0m]; eps: [94m   16[0m; lr: [94m1.00e-03[0m; upd: [94m   192[0m; f: [94m   2.048[0mk; w: [94m0.08[0m; times: [9.1|0.0|6.0] [60|0|40]; took 15.45s [2.1mins left]; mem: 1.25/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:06<03:10, 15.62it/s]


[92mCycle     3/10[0m: r: [91m -115.79[0m; len: [94m  114.00[0m; n: [94m  8[0m; loss: [[94m -0.98[0m|[94m    0.02[0m|[94m -0.12[0m]; eps: [94m   25[0m; lr: [94m1.00e-03[0m; upd: [94m   288[0m; f: [94m   3.072[0mk; w: [94m0.05[0m; times: [9.3|0.0|6.3] [59|0|41]; took 15.08s [1.8mins left]; mem: 1.25/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:06<03:09, 15.67it/s]


[92mCycle     4/10[0m: r: [91m  -73.92[0m; len: [94m  106.88[0m; n: [94m  8[0m; loss: [[94m -0.16[0m|[94m    0.01[0m|[94m -0.08[0m]; eps: [94m   33[0m; lr: [94m1.00e-03[0m; upd: [94m   384[0m; f: [94m   4.096[0mk; w: [94m0.04[0m; times: [8.6|0.0|6.3] [58|0|42]; took 15.38s [1.5mins left]; mem: 1.25/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:06<03:13, 15.41it/s]


[92mCycle     5/10[0m: r: [91m  -89.32[0m; len: [94m  101.00[0m; n: [94m  9[0m; loss: [[94m -0.64[0m|[94m    0.03[0m|[94m -0.14[0m]; eps: [94m   41[0m; lr: [94m1.00e-03[0m; upd: [94m   480[0m; f: [94m   5.120[0mk; w: [94m0.1[0m; times: [8.9|0.0|6.4] [58|0|42]; took 14.84s [1.3mins left]; mem: 1.26/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:06<03:08, 15.75it/s]


[92mCycle     6/10[0m: r: [91m  -86.04[0m; len: [94m  109.50[0m; n: [94m  8[0m; loss: [[94m -0.60[0m|[94m    0.01[0m|[94m -0.26[0m]; eps: [94m   50[0m; lr: [94m1.00e-03[0m; upd: [94m   576[0m; f: [94m   6.144[0mk; w: [94m0.05[0m; times: [8.2|0.0|6.3] [57|0|43]; took 15.11s [1.0mins left]; mem: 1.26/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:06<03:07, 15.88it/s]


[92mCycle     7/10[0m: r: [91m -112.56[0m; len: [94m  131.00[0m; n: [94m  7[0m; loss: [[94m -0.41[0m|[94m    0.06[0m|[94m -0.14[0m]; eps: [94m   58[0m; lr: [94m1.00e-03[0m; upd: [94m   672[0m; f: [94m   7.168[0mk; w: [94m0.06[0m; times: [8.7|0.0|6.2] [58|0|42]; took 15.46s [0.8mins left]; mem: 1.27/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:06<03:08, 15.82it/s]


[92mCycle     8/10[0m: r: [91m  -90.82[0m; len: [94m  123.29[0m; n: [94m  7[0m; loss: [[94m -0.44[0m|[94m    0.01[0m|[94m -0.39[0m]; eps: [94m   65[0m; lr: [94m1.00e-03[0m; upd: [94m   768[0m; f: [94m   8.192[0mk; w: [94m0.06[0m; times: [9.1|0.0|6.2] [59|0|41]; took 14.81s [0.5mins left]; mem: 1.27/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:06<03:09, 15.74it/s]


[92mCycle     9/10[0m: r: [91m  -68.38[0m; len: [94m  116.38[0m; n: [94m  8[0m; loss: [[94m -0.23[0m|[94m    0.01[0m|[94m -0.22[0m]; eps: [94m   72[0m; lr: [94m1.00e-03[0m; upd: [94m   864[0m; f: [94m   9.216[0mk; w: [94m0.04[0m; times: [8.4|0.0|6.3] [57|0|43]; took 14.96s [0.3mins left]; mem: 1.27/33|0.0/0.0;
Optimizing...

Optimizing...:   3%|▎         | 96/3072 [00:06<03:09, 15.73it/s]


Finalizing...Drill finished after 153.46serialization.


In [6]:
evaluation_results = agent.evaluate(10, act_confidently=True)[0]
print(np.mean(evaluation_results.episode_rewards))

-162.23110202608896
