# Batch Learning Tutorials
This tutorial will explain how to train a controller to perform the batch and growing batch learning tasks.

We will be using CartPole-v0 from the OpenAI Gym (now, more correctly, Gymnasium) for the Batch tutorial, and CartPoleSway, a Psiori gym environment, for the Growing Batch tutorial.

In [1]:
import logging

import numpy as np

import tensorflow as tf
from tensorflow.keras import layers as tfkl

from psipy.rl.loop import Loop
from psipy.rl.core.controller import DiscreteRandomActionController, ContinuousRandomActionController
from psipy.rl.io.batch import Batch, Episode

LOG = logging.getLogger("psipy")

The following cell tests whether Gymnasium is installed correctly. You should see a pygame window pop up and some (random) trajectories on the cartpole plant. Please ignore the deprecation warning.

In [None]:
import gymnasium as gym
env = gym.make("CartPole-v0", render_mode="human")
observation, info = env.reset(seed=42)

for _ in range(500):
    action = env.action_space.sample()  # this is where you would insert your policy
    observation, reward, terminated, truncated, info = env.step(action)

    if terminated or truncated:
        observation, info = env.reset()

env.close()

## Background
There are two learning paradigms: batch and growing batch.
### Batch
In the batch paradigm, data is collected through some means, either a random controller, human, or otherwise.  A controller is fit on this data until convergence, and applied to the plant to test final performance.  In simple tasks, this should be sufficient.
<img src="batch-paradigm.png">
### Growing Batch
In the growing batch paradigm, initial data can either be collected the same way as in the batch approach above, or specifically through an exploration policy of the controller.  The controller is then fitted, and it is used to collect *more* data from the plant.  The 'batch of data' grows, and the controller is trained on this bigger batch.  In this way, the controller essentially explores the state space itself, and can 'learn from its mistakes'.
<img src="growing-paradigm.png">
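The difference between the two paradigms can be sketched in a few lines of plain Python. This is a toy illustration only: `collect_episodes` and `fit` are hypothetical stand-ins for running a controller on the plant and for training it. They are not psipy functions, and the episode records are made up.

```python
import random


def collect_episodes(policy_version, n_episodes, rng):
    """Hypothetical stand-in for running a controller on the plant.

    Returns one record per episode, tagged with the version of the
    policy that produced it.
    """
    return [
        {"policy_version": policy_version, "length": rng.randint(10, 30)}
        for _ in range(n_episodes)
    ]


def fit(batch, policy_version):
    """Hypothetical stand-in for training; returns the next policy version."""
    return policy_version + 1


rng = random.Random(0)

# Batch: collect once, fit once, then evaluate the result.
batch = collect_episodes(policy_version=0, n_episodes=10, rng=rng)
batch_policy = fit(batch, policy_version=0)

# Growing batch: alternate between fitting and collecting with the newest policy.
growing = collect_episodes(policy_version=0, n_episodes=10, rng=rng)
policy = 0
sizes = [len(growing)]
for cycle in range(3):
    policy = fit(growing, policy)
    growing += collect_episodes(policy, n_episodes=5, rng=rng)
    sizes.append(len(growing))

print(sizes)  # [10, 15, 20, 25]
```

In the growing variant, later episodes come from later policy versions, which is what lets the controller visit states a random explorer would rarely reach.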

## Batch Tutorial
It is best practice to lay out all your learning components for ease of use and reading.  Here, we define some placeholder variables for our plant, action, state, and lookback (how many observations are in a stack to create a state).

In [2]:
from psipy.rl.plants.gym.cartpole_plants import (
    CartPoleGymAction,
    CartPoleState,
    CartPolePlant,
)

plant = CartPolePlant(use_renderer=True)  # Note that it is instantiated!
action = CartPoleGymAction
state = CartPoleState
lookback = 2
sart_folder = "psidata-tutorial-batch-sart"

We want to use NFQ to control this plant.  Therefore, we need a neural network internal model.  Below we use a function to create the model, but this is not necessary.  Note that we need to use the lookback here to make sure our network's inputs are properly shaped.

In [3]:
from psipy.rl.controllers.nfq import NFQ

def make_model(n_inputs, n_outputs, lookback):
    inp = tfkl.Input((n_inputs, lookback), name="state")
    net = tfkl.Flatten()(inp)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(n_outputs, activation="sigmoid")(net)
    return tf.keras.Model(inp, net)

model = make_model(n_inputs=len(state.channels()), 
                   # CartPole-v0 only has 1 action with 2 values (see CartPoleAction)
                   n_outputs=len(action.legal_values[0]), 
                   lookback=lookback)

# Create the controller with our model.  
controller = NFQ(model=model, state_channels=state.channels(), action=action, action_values=(0,1), lookback=lookback, clamp_terminal_costs=True)

Normalizer not fitted, returning values unchanged.
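To make the lookback concrete, here is a minimal sketch of how consecutive observations could be stacked into states that match the `(n_inputs, lookback)` input shape used in `make_model`. The function `stack_states` is hypothetical. The ordering inside the stack, and how psipy treats the first steps of an episode, are assumptions made for illustration.

```python
def stack_states(observations, lookback):
    """Turn a sequence of observations into stacked states.

    Each observation is a list with one value per channel. Each state is a
    list of channels, and each channel holds the last `lookback` values, so a
    state has shape (n_channels, lookback).
    """
    states = []
    for end in range(lookback, len(observations) + 1):
        window = observations[end - lookback:end]
        # Transpose the window from (lookback, n_channels) to (n_channels, lookback).
        states.append([list(channel) for channel in zip(*window)])
    return states


# Three observations with four channels each (values are made up).
observations = [
    [0.0, 0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2, 1.3],
    [2.0, 2.1, 2.2, 2.3],
]
states = stack_states(observations, lookback=2)
print(len(states))                         # 2
print(len(states[0]), len(states[0][0]))   # 4 2
print(states[0][0])                        # [0.0, 1.0]
```

With a lookback of 2, three observations yield two states, which is why the lookback has to be known both when building the network and when loading the data.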


We also want a controller to explore.  We could use NFQ to explore, since with randomized weights it essentially acts as a random controller, but we will explicitly use a random controller here to demonstrate how to use different controllers at the same time. Since `CartPoleAction` is discrete, we use a `DiscreteRandomActionController`.

In [4]:
explorer = DiscreteRandomActionController(state_channels=state.channels(), action=action)
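Conceptually, a random discrete explorer does very little. The sketch below is a hypothetical stand-in, not the implementation of `DiscreteRandomActionController`: it ignores the state and draws uniformly from the legal action values. The class name, the `get_action` method, and the `seed` argument are assumptions for illustration.

```python
import random


class RandomDiscretePolicy:
    """Hypothetical minimal explorer: ignores the state, picks a legal value uniformly."""

    def __init__(self, legal_values, seed=None):
        self.legal_values = tuple(legal_values)
        self.rng = random.Random(seed)

    def get_action(self, state):
        # The state is unused on purpose; exploration here is purely random.
        return self.rng.choice(self.legal_values)


policy = RandomDiscretePolicy(legal_values=(0, 1), seed=42)
actions = [policy.get_action(state=None) for _ in range(1000)]
print(sorted(set(actions)))  # [0, 1]
```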

Now we create the `Loop`, which takes a plant, a controller, a name (the name used in the SART logs), and a path in which to save the SART logs.

In [5]:
loop = Loop(plant, explorer, "GymCartPole", sart_folder, render=False)

Let's now collect some data with our explorer, the `DiscreteRandomActionController`.  We want to collect 10 episodes, and since the Gym environments already enforce `max_episode_steps`, we don't have to specify that parameter.

**Attention:** Please ignore the "ValueErrors" thrown by hdf5. 

In [6]:
loop.run(10)

Exception ignored in: <function ExpandableDataset.__del__ at 0x31b9a0ca0>
Traceback (most recent call last):
  File "/Users/slange/Code/psiori/psipy-public/psipy/rl/io/sart.py", line 114, in __del__
    self.finalize()
  File "/Users/slange/Code/psiori/psipy-public/psipy/rl/io/sart.py", line 147, in finalize
    self._resize(self.rows)
  File "/Users/slange/Code/psiori/psipy-public/psipy/rl/io/sart.py", line 121, in _resize
    raise e
  File "/Users/slange/Code/psiori/psipy-public/psipy/rl/io/sart.py", line 118, in _resize
    self.dataset.resize((rows, *self.incoming_shape))
  File "/Users/slange/Code/psiori/.venv/lib/python3.8/site-packages/h5py/_hl/dataset.py", line 659, in resize
    self.id.set_extent(size)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5d.pyx", line 277, in h5py.h5d.DatasetID.set_extent
ValueError: Invalid dataset identifier (invalid dataset identifier)

{1: {'total_cost': 0.13202781250473705,
  'cycles_run': 18,
  'wall_time_s': 1.5937},
 2: {'total_cost': 0.1626111401772743, 'cycles_run': 24, 'wall_time_s': 1.017},
 3: {'total_cost': 0.15032251799099866,
  'cycles_run': 25,
  'wall_time_s': 1.0376},
 4: {'total_cost': 0.10993117104801443,
  'cycles_run': 14,
  'wall_time_s': 0.802},
 5: {'total_cost': 0.15429976422719158,
  'cycles_run': 23,
  'wall_time_s': 1.0367},
 6: {'total_cost': 0.21805143973561636,
  'cycles_run': 27,
  'wall_time_s': 1.0856},
 7: {'total_cost': 0.12720278860690848,
  'cycles_run': 13,
  'wall_time_s': 0.7935},
 8: {'total_cost': 0.1274330436042014,
  'cycles_run': 19,
  'wall_time_s': 0.9156},
 9: {'total_cost': 0.1697660091641888,
  'cycles_run': 22,
  'wall_time_s': 0.9826},
 10: {'total_cost': 0.1342376697336599,
  'cycles_run': 14,
  'wall_time_s': 0.8163}}

Be aware that if you run this notebook multiple times, old data collected from previous runs will also be loaded unless you have deleted the SART folder.
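If you want a clean start on every run, one option is to remove the folder before collecting data. The helper below is a minimal sketch using only the standard library; `reset_folder` is a hypothetical name, and the demonstration runs on a throwaway temporary directory so that no real data is touched.

```python
import shutil
import tempfile
from pathlib import Path


def reset_folder(path):
    """Delete `path` if it exists and report whether anything was removed."""
    folder = Path(path)
    if folder.is_dir():
        shutil.rmtree(folder)
        return True
    return False


# Demonstration on a throwaway directory.
demo_root = Path(tempfile.mkdtemp())
demo_sart = demo_root / "demo-sart"
demo_sart.mkdir()
(demo_sart / "episode.hdf5").write_bytes(b"")

print(reset_folder(demo_sart))  # True
print(reset_folder(demo_sart))  # False
shutil.rmtree(demo_root)
```

In this notebook the call would be `reset_folder(sart_folder)`, placed before the first `loop.run`.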

We now load the data into a `Batch` from the hdf5 files we just created.  Be aware that you have to set the lookback here!

In [7]:
batch = Batch.from_hdf5(sart_folder, lookback=lookback, control=controller, action_channels=["move_index",])

Also note the logs: at the bottom, the logs will always tell you how many episodes were loaded.  If you want to know at any other point how many episodes the batch has loaded, check the `num_episodes` property.

In [8]:
batch.num_episodes

10

Neural networks like normalized data, so we fit NFQ's normalizer by passing in the observations in the batch.

In [9]:
controller.fit_normalizer(batch.observations)
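As a rough picture of what fitting a normalizer involves, the sketch below computes a per-channel mean and standard deviation and uses them to rescale observations. This is one common scheme and an assumption about what happens inside `fit_normalizer`; psipy's normalizer may work differently. The function names and data are made up.

```python
import statistics


def fit_meanstd(observations):
    """Return per-channel (means, stds) for a list of observations."""
    channels = list(zip(*observations))
    means = [statistics.fmean(ch) for ch in channels]
    # Population std; fall back to 1.0 for constant channels to avoid dividing by zero.
    stds = [statistics.pstdev(ch) or 1.0 for ch in channels]
    return means, stds


def normalize(observation, means, stds):
    """Shift and scale one observation channel by channel."""
    return [(x - m) / s for x, m, s in zip(observation, means, stds)]


observations = [[0.0, 10.0], [2.0, 10.0], [4.0, 10.0]]
means, stds = fit_meanstd(observations)
normalized = [normalize(o, means, stds) for o in observations]
print(means)  # [2.0, 10.0]
```

After normalization the first channel has zero mean and unit standard deviation, while the constant second channel maps to zero.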


Now it is time for fitting.  We pass in the batch, and train.  This will take a couple of minutes.

In [10]:
controller.fit(
    batch,
    iterations=2,
    epochs=10,
    minibatch_size=32,
    gamma=0.99,
    verbose=1,
)

qtargets n: 189 max: 1.0 min: 0.28722308948636055
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
qtargets n: 189 max: 1.0 min: 0.26617361325770617
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
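Each `qtargets` line in the log marks one NFQ iteration: the targets are recomputed from the current network, and the network is then trained on them for `epochs` epochs. The sketch below shows the textbook cost-based Q-learning target, `cost + gamma * min Q(next_state, a)`. It is not psipy's code. The fixed cost for terminal transitions and the clipping to `[0, 1]` (to match the sigmoid output layer) are assumptions for illustration.

```python
def q_targets(costs, next_q_values, terminals, gamma=0.99, terminal_cost=1.0):
    """Compute cost-based Q-learning targets for a batch of transitions.

    costs:          immediate cost of each transition
    next_q_values:  for each transition, the Q-values of all actions in the next state
    terminals:      True where the transition ended the episode
    """
    targets = []
    for cost, next_q, terminal in zip(costs, next_q_values, terminals):
        if terminal:
            # Assumption: a terminal transition is assigned a fixed cost.
            target = terminal_cost
        else:
            # Costs are minimized, so bootstrap from the cheapest next action.
            target = cost + gamma * min(next_q)
        # Keep targets inside the range a sigmoid output can represent.
        targets.append(min(max(target, 0.0), 1.0))
    return targets


targets = q_targets(
    costs=[0.01, 0.02],
    next_q_values=[[0.3, 0.5], [0.9, 0.8]],
    terminals=[False, True],
)
```

Here the first target is `0.01 + 0.99 * 0.3`, and the second is the terminal cost.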


Hurray!  Now we can see how our trained controller fares live in the plant.  Let's run the controller by creating a new `Loop`, but this time we do not store the data in our SART folder (otherwise, if we trained again, we would train on this data as well, i.e. growing batch).  We do this by setting the `logdir` param to something different from our SART folder; here we append "-evaluation" to the SART folder name.  Finally, we render the environment so we can see what is happening.  Enjoy!

*Note: the environment will not close on its own. We know of this issue and won't fix it (yet) :D*

In [11]:
eval_loop = Loop(plant, controller, "CartPoleEval", f"{sart_folder}-evaluation", render=True)
eval_loop.run(5)

{1: {'total_cost': 0.07455331778864403,
  'cycles_run': 10,
  'wall_time_s': 0.7246},
 2: {'total_cost': 0.07836104670841532,
  'cycles_run': 10,
  'wall_time_s': 0.7341},
 3: {'total_cost': 0.06271338455808305,
  'cycles_run': 10,
  'wall_time_s': 0.7312},
 4: {'total_cost': 0.22036167628293787,
  'cycles_run': 24,
  'wall_time_s': 1.0235},
 5: {'total_cost': 0.23307800374447404,
  'cycles_run': 25,
  'wall_time_s': 1.0496}}
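The dictionary shown as cell output can be summarized with a small helper. The sketch below assumes the result has the structure displayed above (episode number mapped to a dict with `cycles_run` and other keys). `mean_cycles` is a hypothetical helper, and the numbers are copied from the evaluation output.

```python
def mean_cycles(results):
    """Average number of cycles per episode in a results dict."""
    return sum(ep["cycles_run"] for ep in results.values()) / len(results)


# Cycle counts and wall times copied from the evaluation output above.
eval_results = {
    1: {"cycles_run": 10, "wall_time_s": 0.7246},
    2: {"cycles_run": 10, "wall_time_s": 0.7341},
    3: {"cycles_run": 10, "wall_time_s": 0.7312},
    4: {"cycles_run": 24, "wall_time_s": 1.0235},
    5: {"cycles_run": 25, "wall_time_s": 1.0496},
}
print(mean_cycles(eval_results))  # 15.8
```

For comparison, the ten exploration episodes collected earlier sum to 199 cycles, or 19.9 on average. Two NFQ iterations on ten episodes have therefore not yet produced a controller that outlasts the random explorer.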

The following cell grows a batch of data, interleaved with improving the model. This will take a while because of the network training in between the collection of episodes of interaction with the cartpole. It should learn to balance the pole within 10 cycles of the outer loop or fewer.

In [None]:
num_cycles = 10
iterations = 5

loop = Loop(plant, controller, "GymCartPole", sart_folder, render=True)

for cycle in range(num_cycles):
    loop.run(5)

    batch.append_from_hdf5(sart_folder, action_channels=["move_index",])
    print(f"Current batch size: {batch.num_episodes}")

    controller.fit(
        batch,
        iterations=20,
        epochs=25,
        minibatch_size=64,
        gamma=0.99,
        verbose=1,
    )

Current batch size: 15
qtargets n: 256 max: 1.0 min: 0.2847381653264165
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 256 max: 1.0 min: 0.3233471615239978
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 256 max: 1.0 min: 0.2684624781832099
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch

Current batch size: 20
qtargets n: 414 max: 1.0 min: 0.10188477300107479
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 414 max: 1.0 min: 0.11547465715557337
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 414 max: 1.0 min: 0.11841126624494791
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Ep

Current batch size: 25
qtargets n: 1277 max: 1.0 min: 0.00594094954431057
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 1277 max: 1.0 min: 0.007434595376253128
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 1277 max: 1.0 min: 0.007667331490665674
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/

Current batch size: 30
qtargets n: 1586 max: 1.0 min: 0.038386713713407516
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 1586 max: 1.0 min: 0.030836308375000954
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 1586 max: 1.0 min: 0.03328525321558118
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/

Current batch size: 35
qtargets n: 2144 max: 1.0 min: 0.05509443115442991
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 2144 max: 1.0 min: 0.0537619236856699
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
qtargets n: 2144 max: 1.0 min: 0.0565631405916065
Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
E

**Attention:** everything below has not yet been adapted to psipy-public and will not work right now. WIP.

## Growing Batch Tutorial
We will use the `CartPoleSway` env here, since `CartPole-v0` has a hardcoded discrete action space.

As with the Batch tutorial, we will lay out all of our learning components for ease of use and reading.  Here, we define some placeholder variables for our plant, action, state, and lookback (how many observations are in a stack to create a state).

In [None]:
from psipy.rl.plant.gym.cartpole_plants import (
    CartPoleSwayContAction,
    CartPoleSwayContinuousPlant,
    CartPoleSwayState,
)

plant = CartPoleSwayContinuousPlant()  # Note that it is instantiated!
action = CartPoleSwayContAction
state = CartPoleSwayState
lookback = 1
sart_folder = "tutorial-growing-sart"

We want to use NFQCA to control this plant.  Therefore, we need two neural network internal models, one for the actor and one for the critic.  Below we use functions to create the models, but this is not necessary.  Note that we need to use the lookback to properly shape our neural networks' inputs.

In [None]:
from psipy.rl.control.nfqca import NFQCA

def make_actor(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_actor")
    net = tfkl.Flatten()(inp)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="tanh")(net)
    return tf.keras.Model(inp, net, name="actor")


def make_critic(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_critic")
    act = tfkl.Input((1,), name="act_in")
    net = tfkl.Concatenate()([tfkl.Flatten()(inp), tfkl.Flatten()(act)])
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="sigmoid")(net)
    return tf.keras.Model([inp, act], net, name="critic")

actor = make_actor(len(state.channels()), lookback)
critic = make_critic(len(state.channels()), lookback)

controller = NFQCA(
    actor=actor, 
    critic=critic, 
    state_channels=state.channels(), 
    action=CartPoleSwayContAction,
    lookback=lookback
)

Now we create the `Loop`, which takes a plant, a controller, a name (the name used in the SART logs), and a path in which to save the SART logs.

In [None]:
loop = Loop(plant, controller, "CartPoleSway", sart_folder)

Now we will use NFQCA to do the initial exploration, as well as all data collection.  We first collect some initial data outside the loop and then collect more data within the growing-batch-loop.  We print some extra things so you can see what is going on, but that is unnecessary.

Note that depending on the internal cost function of the `CartPoleSwayEnv` at this time, NFQCA might learn nothing.  This tutorial just shows the general form of training NFQCA.

In [None]:
num_cycles = 3
iterations = 2

loop.run(10)
batch = Batch.from_hdf5(sart_folder, lookback=lookback, control=controller)

for cycle in range(num_cycles):
    LOG.info("Cycle: %d", cycle + 1)
    print(f"Current batch size: {batch.num_episodes}")
    # Fit the normalizer on the data. Refitting on each cycle lets the fit
    # parameters home in on the true population parameters as the batch grows
    # (see the Batch Tutorial above for how the normalizer is fitted)
    controller.fit_normalizer(batch.observations, method="meanstd")

    # NFQCA does not have a generic fit method
    for iteration in range(iterations):
        LOG.info("NFQCA Iteration: %d", iteration + 1)
        controller.fit_critic(batch,
                         iterations=1,
                         epochs=10,
                         minibatch_size=-1,
                         gamma=1.0,
                         verbose=0)
        controller.fit_actor(batch,
                        epochs=10,
                        minibatch_size=-1,
                        verbose=0)

    loop = Loop(plant, controller, "GrowingBatch", sart_folder)
    loop.run(5)
    # Batch.append_from_hdf5() appends any new files found in the folder
    batch.append_from_hdf5(sart_folder)

Now we can see how the model performs live.

In [None]:
loop = Loop(plant, controller, "GrowingBatchEval", f"live-{sart_folder}", render=True)
loop.run(5)

## Advanced Learning Tutorial

You've made it to the next level! Congratulations.  Now it is time to delve deeper into the power of offline reinforcement learning.  The general training guidelines will not be outlined here.

Let's train `CartPoleSway` again, but this time alter the cost function and create fake transitions to aid the cart to its goal.

### Problem Setting
We want the cart to move to position 0 (middle of the screen) from any starting position.  We do not care about the pole.

For this, we will generate a cost function that only deals out cost based on position, and create a fake transition set that shows 0 cost at the goal position.

*Note the imports inside the functions: this is bad practice but I do it here to show what imports we need.*

In [None]:
def create_fake_episodes(sart_path:str, lookback:int): # ->List[Episode]
    """Create a fake episode at position 0 for every episode already collected
    
    Since the cart probably was never at position 0 exactly, we add a set of Episodes with this transition.
    """
    import glob
    from psipy.rl.io.sart import SARTReader
    from psipy.rl.io.batch import Episode
    
    more_episodes = []

    for path in glob.glob(f"{sart_path}/*.hdf5"):
        with SARTReader(path) as reader:
            o, a, t, c = reader.load_full_episode()

            # Add episode full of goal states
            o = o.copy()
            a = a.copy()
            o[:, 0] = 0
            o[:, 1] = 0
            # A swinging pole affects the position of the cart,
            # so we say no swing here as well
            o[:, 2] = 180
            o[:, 3] = 0
            a[:] = 0  # The cart should not move once in this position
            more_episodes.append(Episode(o, a, t, c, lookback=lookback))
            
    return more_episodes

def costfunc(states:np.ndarray): # -> np.ndarray
    """Recalculate costs on all states provided
    
    This calculates costs on multiple states, so it returns an array
    """
    from psipy.rl.control.nfq import tanh2
    # Position is already defined against 0 (relative)
    position = states[:, 0]
    cost = tanh2(position, C=.2, mu=.1)
    
    return cost
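For intuition about the shape of such a cost, here is a standalone sketch of a squared-tanh cost as it is commonly written in the NFQ literature: zero at the goal, rising smoothly, and saturating towards `C`, with the width chosen so that the cost reaches `0.95 * C` at a distance of `mu`. Whether psipy's `tanh2` uses exactly this parameterization is an assumption; `tanh2_cost` is a hypothetical name.

```python
import math


def tanh2_cost(x, C=1.0, mu=1.0):
    """Smooth cost that is 0 at x == 0 and saturates towards C for large |x|.

    The width is chosen so that the cost reaches 0.95 * C at |x| == mu.
    """
    w = math.atanh(math.sqrt(0.95)) / mu
    return C * math.tanh(abs(x) * w) ** 2


print(tanh2_cost(0.0, C=0.2, mu=0.1))  # 0.0
```

With `C=.2` and `mu=.1` as in `costfunc`, a cart within 0.1 of the goal pays less than 95% of the maximum cost of 0.2, and the cost flattens out further away.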

Let's set everything up until we need to add the fake transitions.

In [None]:
from psipy.rl.plant.gym.cartpole_plants import (
    CartPoleSwayContAction,
    CartPoleSwayContinuousPlant,
    CartPoleSwayState,
)
from psipy.rl.control.nfqca import NFQCA


plant = CartPoleSwayContinuousPlant()  # Note that it is instantiated!
action = CartPoleSwayContAction
state = CartPoleSwayState
lookback = 5
sart_folder = "tutorial-advanced-sart"


def make_actor(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_actor")
    net = tfkl.Flatten()(inp)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="tanh")(net)
    return tf.keras.Model(inp, net, name="actor")


def make_critic(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_critic")
    act = tfkl.Input((1,), name="act_in")
    net = tfkl.Concatenate()([tfkl.Flatten()(inp), tfkl.Flatten()(act)])
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="sigmoid")(net)
    return tf.keras.Model([inp, act], net, name="critic")

actor = make_actor(len(state.channels()), lookback)
critic = make_critic(len(state.channels()), lookback)

controller = NFQCA(
    actor=actor, 
    critic=critic, 
    state_channels=state.channels(), 
    action=CartPoleSwayContAction,
    lookback=lookback
)

loop = Loop(plant, controller, "CartPolePosition", sart_folder)

The critic receives the cost function, and before we start training we want to add our fake goal position transitions.  Therefore, we put `costfunc` in `fit_critic` and append fake episodes before we start the fitting cycles.

We will print the size of the batch so it can be seen how the fake episodes increase the batch episode count.

In [None]:
num_cycles = 3
iterations = 3

loop.run(10)
batch = Batch.from_hdf5(sart_folder, lookback=lookback, control=controller)
print(f"Current batch size: {batch.num_episodes}")
# We now append the fake created episodes
# We could also throw away the created batch and only have the episodes
# created in the function by doing:
# batch = Batch(create_fake_episodes(sart_folder, lookback))
batch = batch.append(create_fake_episodes(sart_folder, lookback))
        
for cycle in range(num_cycles):
    LOG.info("Cycle: %d", cycle + 1)
    print(f"Current batch size: {batch.num_episodes}")
    controller.fit_normalizer(batch.observations, method="meanstd")

    for iteration in range(iterations):
        LOG.info("NFQCA Iteration: %d", iteration + 1)
        controller.fit_critic(batch,
                         iterations=1,
                         # We add the cost function here
                         costfunc=costfunc,
                         epochs=10,
                         minibatch_size=-1,
                         gamma=1.0,
                         verbose=0)
        controller.fit_actor(batch,
                        epochs=10,
                        minibatch_size=-1,
                        verbose=0)

    loop = Loop(plant, controller, "GrowingBatch", sart_folder)
    loop.run(10)
    batch.append_from_hdf5(sart_folder)

Let's see if the controller moves to the goal position.  Oh the tension!

In [None]:
loop = Loop(plant, controller, "GrowingBatchEval", f"live-{sart_folder}", render=True)
loop.run(3)

Congratulations, you've made it through the Advanced Tutorial.  You are now an expert!

## Please delete the SART logs created by this tutorial if you no longer need them.  Or run the cell below to do that automatically.

In [None]:
import shutil
dirs = ["psidata-tutorial-batch-sart",
        "psidata-tutorial-batch-sart-evaluation",
        "tutorial-growing-sart",
        "live-tutorial-growing-sart",
        "tutorial-advanced-sart",
        "live-tutorial-advanced-sart"]
for d in dirs:
    try:
        shutil.rmtree(d)
    except Exception as e:
        print(e)