# 2. Building Custom Models

When you are using AngoraPy to train goal-driven models of the brain, you will usually want these models to be _your_ custom networks. In this notebook, we show you the basics of constructing and registering your own model.

We first import necessary dependencies and then, like in the previous tutorial, build the environment and distribution.

In [4]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import tensorflow as tf
import angorapy as ap

env = ap.make_env("CartPole-v1")
distribution = ap.policies.CategoricalPolicyDistribution(env)

If you would like to, you can also go for a different environment, for instance `LunarLander-v2`. However, training it will need a little more time and potentially a stronger network.

We will now build the model. In AngoraPy, we do not handle models themselves, but instead operate on _model builders_. The reason for this is practicality in the backend of the library. Because we use truncated backpropagation through time, models need to be stateful (for an explanation of this, check the paper introducing AngoraPy). This requires models to be built with a specific sequence length. However, we collect data in single steps but optimize on longer sequences. This forces us to constantly rebuild the model, which calls for functions instead of objects.
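
To make the builder idea concrete, here is a minimal sketch in plain Python. It does not use AngoraPy or TensorFlow; `ToyModel` and `build_sketch_model` are hypothetical stand-ins, and the batch size and sequence length used for optimization are arbitrary values chosen for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyModel:
    """Hypothetical stand-in for a network whose input shape is fixed at build time."""
    batch_size: int
    sequence_length: int


def build_sketch_model(bs=1, sequence_length=1):
    # A builder is just a function: every call returns a fresh model
    # constructed for the requested batch size and sequence length.
    return ToyModel(batch_size=bs, sequence_length=sequence_length)


# Data gathering steps through the environment one step at a time ...
gathering_model = build_sketch_model(bs=1, sequence_length=1)
# ... while optimization runs on longer sequences (32 and 16 are arbitrary here).
optimization_model = build_sketch_model(bs=32, sequence_length=16)

print(gathering_model)
print(optimization_model)
```

Because the shape is baked in at construction, a single model object could not serve both phases, whereas a function can be called again whenever a differently shaped model is needed.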

Anyways. The functions we write to build our models have some requirements.

1. Their signature must follow the format `function(env, distribution, bs, sequence_length) -> policy, value, joint`, where `env` and `distribution` are the environment and distribution the model will act upon, and `policy`, `value` and `joint` are the models (more about this below).
2. All recurrent elements need to return sequences and be stateful. In this notebook, however, we will start with a simple feedforward network. Only in the following notebook will we integrate a recurrent part.
3. They have to be registered in AngoraPy for later reference using the `ap.models.register_model("MODELNAME")` decorator.
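
The first and third requirements can be illustrated with a generic registry decorator. The following is only a sketch of the pattern in plain Python: `MODEL_REGISTRY`, `REQUIRED_PARAMETERS` and this `register_model` are hypothetical names written for illustration, and AngoraPy's own implementation of `ap.models.register_model` may work differently.

```python
import inspect

# Hypothetical registry; AngoraPy's own bookkeeping may look different.
MODEL_REGISTRY = {}

REQUIRED_PARAMETERS = ["env", "distribution", "bs", "sequence_length"]


def register_model(name):
    """Return a decorator that stores a builder function under `name`."""
    def decorator(builder):
        # Check that the builder follows the expected signature.
        parameters = list(inspect.signature(builder).parameters)
        if parameters[:4] != REQUIRED_PARAMETERS:
            raise TypeError(f"{builder.__name__} must accept {REQUIRED_PARAMETERS}")
        MODEL_REGISTRY[name] = builder
        return builder
    return decorator


@register_model("ToyModel")
def build_toy_model(env, distribution, bs=1, sequence_length=None):
    # Placeholders standing in for the policy, value and joint networks.
    return "policy", "value", "joint"


# Later, the builder can be looked up by its registered name.
builder = MODEL_REGISTRY["ToyModel"]
policy, value, joint = builder(env=None, distribution=None)
print(policy, value, joint)
```

Registering under a name means a saved agent only needs to store that name to find its builder again, provided the module defining the builder has been imported first.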

We write a builder function for a 5-layer network that partially shares weights between the policy and value networks.

In [5]:
from tensorflow.keras.layers import TimeDistributed
from angorapy.utilities.model_utils import env_extract_dims, make_input_layers

@ap.models.register_model("MyModel")
def build_my_amazing_model(env, distribution, bs=1, sequence_length=None):
    # input layer for the proprioceptive observation and the number of actions
    inputs = make_input_layers(env, bs)["proprioception"]
    _, n_actions = env_extract_dims(env)

    # two layers shared by the policy and the value network
    x = tf.keras.layers.Dense(8, activation="relu")(inputs)
    x = tf.keras.layers.Dense(8, activation="relu")(x)

    # policy branch
    x_policy = tf.keras.layers.Dense(8, activation="relu")(x)
    x_policy = tf.keras.layers.Dense(8, activation="relu")(x_policy)

    # value branch
    x_value = tf.keras.layers.Dense(8, activation="relu")(x)
    x_value = tf.keras.layers.Dense(8, activation="relu")(x_value)

    # the distribution builds the action head; the value head is a single unit
    out_policy = distribution.build_action_head(n_actions, x_policy.shape[1:], bs)(x_policy)
    out_value = tf.keras.layers.Dense(1)(x_value)

    policy = tf.keras.Model(inputs=inputs, outputs=out_policy, name="my_policy_function")
    value = tf.keras.Model(inputs=inputs, outputs=out_value, name="my_value_function")
    joint = tf.keras.Model(inputs=inputs, outputs=[out_policy, out_value], name="my_joint_networks")

    return policy, value, joint
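
To get a feeling for how much of this network is shared, we can count the parameters of the dense layers by hand. The sketch below is plain Python arithmetic, assuming the 4-dimensional CartPole observation as input. It leaves out the action head, since its internals are created by the distribution and are not spelled out here. `dense_params` is a helper written for this example.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


n_observations = 4  # size of a CartPole-v1 observation
hidden = 8          # units per hidden layer in the model above

shared = dense_params(n_observations, hidden) + dense_params(hidden, hidden)
policy_branch = 2 * dense_params(hidden, hidden)
value_branch = 2 * dense_params(hidden, hidden) + dense_params(hidden, 1)

print(f"shared trunk:  {shared}")         # 40 + 72 = 112
print(f"policy branch: {policy_branch}")  # 72 + 72 = 144 (without action head)
print(f"value branch:  {value_branch}")   # 144 + 9 = 153
```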

We then build the agent with our model function and plot the model for inspection. (Note that plotting requires pydot and graphviz to be installed on our machine.)

In [6]:
from tensorflow.keras.utils import plot_model

agent = ap.Agent(build_my_amazing_model, env, horizon=2048, workers=1, distribution=distribution)
plot_model(agent.joint)

You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.


And that's all we need to train our model on the task, as we have previously seen in the first notebook.

In [7]:
agent.drill(n=10, epochs=10, batch_size=32)
agent.save_agent_state()



Drill started using 1 processes for 1 workers of which 1 are optimizers. Worker distribution: [1].
IDs over Workers: [[0]]
IDs over Optimizers: [[0]]

Gathering cycle 0...

Before Training; r:    21.03; len:    21.03; n:  97; loss: [  pi  |  v     |  ent ]; upd:      0; y.exp: 0.000; ; time:  ; time left: unknown time; took s [unknown time left]

Gathering cycle 1...

Cycle     1/10; r:    40.14; len:    40.14; n:  51; loss: [  0.03|    0.43|  0.68]; upd:    640; ; time: [16.3|0.0|9.4] [63|0|37]; time left: 3.7mins; took 24.76s [3.7mins left]

Gathering cycle 2...

Cycle     2/10; r:    56.50; len:    56.50; n:  36; loss: [  0.07|    0.31|  0.64]; upd:   1280; ; time: [15.0|0.0|8.4] [64|0|36]; time left: 3.3mins; took 25.23s [3.3mins left]

Gathering cycle 3...

Cycle     3/10; r:    92.62; len:    92.62; n:  21; loss: [  0.02|    0.17|  0.60]; upd:   1920; ; time: [16.5|0.0|8.8] [65|0|35]; time left: 2.9mins; took 25.1s [2.9mins left]

Gathering cycle 4...

Cycle     4/10; r:   136.36; len:   136.36; n:  14; loss: [ -0.04|    0.10|  0.57]; upd:   2560; ; time: [16.0|0.0|9.1] [64|0|36]; time left: 2.5mins; took 25.29s [2.5mins left]

Gathering cycle 5...

Cycle     5/10; r:   328.00; len:   328.00; n:   6; loss: [ -0.10|    0.06|  0.53]; upd:   3200; ; time: [15.9|0.0|9.5] [63|0|37]; time left: 2.1mins; took 24.87s [2.1mins left]

Gathering cycle 6...

Cycle     6/10; r:   201.40; len:   201.40; n:  10; loss: [ -0.05|    0.03|  0.49]; upd:   3840; ; time: [15.1|0.0|8.3] [64|0|36]; time left: 1.7mins; took 25.62s [1.7mins left]

Gathering cycle 7...

Cycle     7/10; r:   192.20; len:   192.20; n:  10; loss: [  0.01|    0.03|  0.55]; upd:   4480; ; time: [17.1|0.0|9.2] [65|0|35]; time left: 1.3mins; took 25.02s [1.3mins left]

Gathering cycle 8...

Cycle     8/10; r:   187.30; len:   187.30; n:  10; loss: [ -0.01|    0.02|  0.53]; upd:   5120; ; time: [15.5|0.0|8.9] [64|0|36]; time left: 0.8mins; took 24.27s [0.8mins left]

Gathering cycle 9...

Cycle     9/10; r:   176.64; len:   176.64; n:  11; loss: [ -0.08|    0.01|  0.54]; upd:   5760; ; time: [15.1|0.0|9.1] [62|0|38]; time left: 0.4mins; took 25.01s [0.4mins left]

Finalizing...Drill finished after 252.73s.
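
The `upd` column in the log grows by 640 per cycle. This is consistent with one update per minibatch, given the horizon, number of workers, epochs and batch size used above. The short calculation below reproduces the numbers; note that this is our reading of the log and not a documented definition of the column.

```python
horizon = 2048   # steps gathered per worker and cycle (passed to ap.Agent)
workers = 1
epochs = 10      # passes over the gathered data per cycle
batch_size = 32

batches_per_epoch = (horizon * workers) // batch_size
updates_per_cycle = batches_per_epoch * epochs
print(updates_per_cycle)  # 640

# cumulative updates after selected cycles, to compare with the log
for cycle in (1, 2, 9):
    print(cycle, cycle * updates_per_cycle)
```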


Once again, let's evaluate the agent to check how it performs without exploration.

In [8]:
evaluation_results = agent.evaluate(1, act_confidently=True)[0]
print(f"Mean performance after training: {np.mean(evaluation_results.episode_rewards)}")

100%|██████████| 1/1 [00:00<00:00,  1.41it/s]

Mean performance after training: 195.0





Since we might want to skip training at a later stage and instead just load a previously saved agent, let's see how this works. We can load an agent from one of its saved states by calling the static `Agent.from_agent_state()` method, which acts as a constructor. When training an agent using the `drill()` method, your model is saved at every cycle: once as the _last_ agent state and, if it performs better than the previous best, also as the _best_ agent state. Both are constantly overwritten, so there are always two saved states for an agent. In addition, you can manually save an agent by calling `agent.save_agent_state()` and instruct the drill method to save at some frequency. In the following, we load the agent at the default state, `"best"`.
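
The bookkeeping of the _last_ and _best_ states can be pictured with the following plain-Python sketch. `CheckpointSlots` is a hypothetical class written for this illustration, the reward values are made up, and using the reward as the measure of "better" is an assumption; AngoraPy's actual saving logic is not shown here.

```python
class CheckpointSlots:
    """Keeps exactly two slots, 'last' and 'best', that are overwritten over time."""

    def __init__(self):
        self.slots = {}

    def save(self, cycle, reward):
        state = {"cycle": cycle, "reward": reward}
        # 'last' is overwritten at every cycle.
        self.slots["last"] = state
        # 'best' is only overwritten when the new state performs better.
        best = self.slots.get("best")
        if best is None or reward > best["reward"]:
            self.slots["best"] = state


tracker = CheckpointSlots()
for cycle, reward in enumerate([20.0, 55.0, 130.0, 90.0, 110.0]):  # made-up rewards
    tracker.save(cycle, reward)

print(tracker.slots["last"])  # {'cycle': 4, 'reward': 110.0}
print(tracker.slots["best"])  # {'cycle': 2, 'reward': 130.0}
```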

In [9]:
loaded_agent = ap.Agent.from_agent_state(agent.agent_id)

Loading from iteration 9.



We have now fully recovered the agent. However, the loaded agent holds the weights from the best version of itself, whereas the original agent holds those from the last state. Let's evaluate this agent's alter ego.

In [16]:
evaluation_results = loaded_agent.evaluate(1, act_confidently=True)[0]
print(f"Mean performance after training: {np.mean(evaluation_results.episode_rewards)}")

[<KerasTensor: shape=(1, 4) dtype=float32 (created by layer 'proprioception')>]


100%|█████████████████████████████████████████████| 1/1 [00:01<00:00,  1.30s/it]

Mean performance after training: 500.0





This concludes the tutorial on model building. In the following notebook we will revisit the process, but show how to do it with a recurrent network.