# Batch Learning Tutorials
This tutorial will explain how to train a controller to perform the batch and growing batch learning tasks.

We will be using `CartPole-v0` from OpenAI Gym for the Batch tutorial, and `CartPoleSway`, a Psiori gym environment, for the Growing Batch tutorial.

In [1]:
import logging

import numpy as np

import tensorflow as tf
from tensorflow.keras import layers as tfkl

from psipy.rl.loop import Loop
from psipy.rl.control.controller import DiscreteRandomActionController, ContinuousRandomActionController
from psipy.rl.io.batch import Batch, Episode

LOG = logging.getLogger("psipy")

In [None]:
import gymnasium as gym
env = gym.make("CartPole-v0", render_mode="human")
observation, info = env.reset(seed=42)

for _ in range(1000):
    action = env.action_space.sample()  # this is where you would insert your policy
    observation, reward, terminated, truncated, info = env.step(action)

    if terminated or truncated:
        observation, info = env.reset()

env.close()

## Background
There are two learning paradigms: batch and growing batch.
### Batch
In the batch paradigm, data is first collected by some means: a random controller, a human operator, or any other source. A controller is then fitted on this fixed data set until convergence and applied to the plant to test its final performance. For simple tasks, this is often sufficient.
<img src="batch-paradigm.png">
### Growing Batch
In the growing batch paradigm, initial data can either be collected the same way as in the batch approach above, or specifically through an exploration policy of the controller.  The controller is then fitted, and it is used to collect *more* data from the plant.  The 'batch of data' grows, and the controller is trained on this bigger batch.  In this way, the controller essentially explores the state space itself, and can 'learn from its mistakes'.
<img src="growing-paradigm.png">
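
The growing batch cycle described above can be sketched with toy stand-ins. Everything in this block (`collect_episode`, `random_policy`, `fit`, the one-dimensional "plant") is hypothetical and only illustrates the control flow, not psipy's API: collect with an exploration policy, fit on all data so far, collect more with the fitted policy, and repeat.

```python
import random


def collect_episode(policy, rng, max_steps=20):
    """Run one toy episode and return its list of (state, action, cost) transitions."""
    state, transitions = 0.0, []
    for _ in range(max_steps):
        action = policy(state, rng)
        state += action + rng.uniform(-0.1, 0.1)  # toy plant dynamics
        transitions.append((state, action, abs(state)))
    return transitions


def random_policy(state, rng):
    """Exploration policy: ignore the state and pick a random action."""
    return rng.choice([-1.0, 1.0])


def fit(batch):
    """Stand-in for controller training: returns a policy pushing the state to zero."""
    return lambda state, rng: -1.0 if state > 0 else 1.0


rng = random.Random(0)
batch = [collect_episode(random_policy, rng) for _ in range(5)]  # initial batch
sizes = [len(batch)]
for _ in range(3):
    policy = fit(batch)                                         # train on all data so far
    batch += [collect_episode(policy, rng) for _ in range(5)]  # collect more with it
    sizes.append(len(batch))

print(sizes)  # [5, 10, 15, 20]
```

In the plain batch paradigm, only the first collection step and a single `fit` call would take place.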

## Batch Tutorial
It is best practice to lay out all your learning components for ease of use and reading.  Here, we define some placeholder variables for our plant, action, state, and lookback (how many observations are in a stack to create a state).

In [2]:
from psipy.rl.plant.gym.cartpole_plants import (
    CartPoleGymAction,
    CartPoleState,
    CartPolePlant,
)

plant = CartPolePlant(use_renderer=True)  # Note that it is instantiated!
action = CartPoleGymAction
state = CartPoleState
lookback = 2
sart_folder = "data_tutorial-batch-sart"
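
To make the role of `lookback` concrete, here is a minimal sketch of how consecutive observations could be stacked into states. The helper `stack_observations` is hypothetical and not part of psipy; it assumes each state is laid out as (channels, lookback), which matches the arrays printed later in this tutorial.

```python
def stack_observations(observations, lookback):
    """Turn a list of observations into overlapping states of `lookback` steps.

    Each state is laid out as (channels, lookback), i.e. one row per channel.
    """
    states = []
    for end in range(lookback, len(observations) + 1):
        window = observations[end - lookback:end]
        # Transpose the window from (lookback, channels) to (channels, lookback).
        states.append([list(channel) for channel in zip(*window)])
    return states


# Three toy observations with two channels each (position, velocity).
observations = [[0.0, 1.0], [0.1, 1.1], [0.2, 1.2]]
states = stack_observations(observations, lookback=2)
print(states[0])  # [[0.0, 0.1], [1.0, 1.1]]
print(states[1])  # [[0.1, 0.2], [1.1, 1.2]]
```

With a lookback of 2, three observations yield two states, and each state shares one time step with its neighbour.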



We want to use NFQ to control this plant.  Therefore, we need a neural network internal model.  Below we use a function to create the model, but this is not necessary.  Note that we need to use the lookback here to make sure our network's inputs are properly shaped.

In [3]:
from psipy.rl.control.nfq import NFQ

def make_model(n_inputs, n_outputs, lookback):
    inp = tfkl.Input((n_inputs, lookback), name="state")
    net = tfkl.Flatten()(inp)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(n_outputs, activation="sigmoid")(net)
    return tf.keras.Model(inp, net)

model = make_model(
    n_inputs=len(state.channels()),
    # CartPole-v0 has only 1 action with 2 values (see CartPoleGymAction)
    n_outputs=len(action.legal_values[0]),
    lookback=lookback,
)

# Create the controller with our model.
controller = NFQ(
    model=model,
    state_channels=state.channels(),
    action=action,
    action_values=(0, 1),
    lookback=lookback,
)

Normalizer not fitted, returning values unchanged.
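
As a quick sanity check on how the lookback shapes the network, the following sketch counts the trainable parameters of the `Flatten -> Dense(40) -> Dense(40) -> Dense(n_outputs)` architecture by hand. The concrete numbers (5 state channels, 2 action values) are assumptions for illustration; the real values come from `state.channels()` and `action.legal_values`.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def count_params(n_inputs, n_outputs, lookback, hidden=(40, 40)):
    """Count trainable parameters of Flatten -> Dense(40) -> Dense(40) -> Dense(n_outputs)."""
    sizes = [n_inputs * lookback, *hidden, n_outputs]  # Flatten yields n_inputs * lookback values
    return sum(dense_params(a, b) for a, b in zip(sizes, sizes[1:]))


# Assumed values: 5 state channels, 2 discrete action values, lookback of 2.
print(count_params(n_inputs=5, n_outputs=2, lookback=2))  # 2162
```

Doubling the lookback only grows the first layer, since the flattened input is the only place where the lookback enters.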


We also want a controller to explore. We could use NFQ itself, since with randomly initialized weights it essentially acts as a random controller, but we explicitly use a random controller here to demonstrate how different controllers can be used side by side. Since `CartPoleGymAction` is discrete, we use a `DiscreteRandomActionController`.

In [4]:
explorer = DiscreteRandomActionController(state_channels=state.channels(), action=action)
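
Conceptually, a discrete random controller just draws uniformly from the legal action values and ignores the state. The class below is a hypothetical stand-in written for illustration; it is not the `DiscreteRandomActionController` implementation, and its method name is an assumption.

```python
import random


class ToyRandomController:
    """Hypothetical stand-in: picks uniformly among the legal discrete action values."""

    def __init__(self, legal_values, seed=None):
        self.legal_values = tuple(legal_values)
        self.rng = random.Random(seed)

    def get_action(self, state):
        # The state is ignored on purpose: exploration here is purely random.
        return self.rng.choice(self.legal_values)


explorer_sketch = ToyRandomController(legal_values=(0, 1), seed=0)
actions = [explorer_sketch.get_action(state=None) for _ in range(1000)]
print(sorted(set(actions)))  # [0, 1]
```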

Now we create the `Loop`, which takes a plant, a controller, a name (the name used in the SART logs), and a path where the SART logs are saved. Rendering is turned off while collecting data.

In [5]:
loop = Loop(plant, explorer, "GymCartPole", sart_folder, render=False)

Let's now collect some data with our explorer, the `DiscreteRandomActionController`. We collect 5 episodes here; since the Gym environments already enforce `max_episode_steps`, we don't have to specify that parameter.

In [None]:
import gymnasium as gym
env = gym.make("CartPole-v0", render_mode="human")
observation, info = env.reset(seed=42)

env.action_space.contains(1)

In [6]:
loop.run(5)


{1: {'total_cost': 0.0843443163353184,
  'cycles_run': 11,
  'wall_time_s': 1.5995},
 2: {'total_cost': 0.20886756889336489,
  'cycles_run': 22,
  'wall_time_s': 0.9911},
 3: {'total_cost': 0.11888107701482263,
  'cycles_run': 18,
  'wall_time_s': 0.8947},
 4: {'total_cost': 0.12672605589089384,
  'cycles_run': 16,
  'wall_time_s': 0.8548},
 5: {'total_cost': 0.0925760603046529,
  'cycles_run': 12,
  'wall_time_s': 0.7652}}
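
The dictionary above holds one entry per episode. As a rough mental model of how such per-episode metrics could be produced, here is a toy interaction loop. The function `run_episodes` and its callbacks are hypothetical; the sketch mirrors only the structure of the output, not psipy's `Loop`.

```python
import random
import time


def run_episodes(step_cost, is_terminal, n_episodes, max_steps=200, seed=0):
    """Toy interaction loop: returns {episode: {"total_cost", "cycles_run", "wall_time_s"}}."""
    rng = random.Random(seed)
    metrics = {}
    for episode in range(1, n_episodes + 1):
        start, total_cost, cycles = time.perf_counter(), 0.0, 0
        for _ in range(max_steps):
            action = rng.choice((0, 1))      # stands in for the controller choosing an action
            total_cost += step_cost(action)  # stands in for the plant's cost signal
            cycles += 1
            if is_terminal(cycles, rng):
                break
        metrics[episode] = {
            "total_cost": total_cost,
            "cycles_run": cycles,
            "wall_time_s": round(time.perf_counter() - start, 4),
        }
    return metrics


metrics = run_episodes(
    step_cost=lambda action: 0.01,
    is_terminal=lambda cycles, rng: rng.random() < 0.1,  # toy termination rule
    n_episodes=5,
)
print(sorted(metrics))  # [1, 2, 3, 4, 5]
```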

In [7]:
from pprint import pprint

a = action({'move': 1})
pprint(a.as_dict())

{'move': 1}



Be aware that if you run this notebook multiple times, old data collected from previous runs will also be loaded unless you have deleted the SART folder.
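
If you want a clean start, one option is to delete the folder before collecting data. The helper `clear_sart_folder` below is a hypothetical convenience function, not part of psipy, and it is demonstrated on a throwaway directory so that no real data is touched.

```python
import shutil
import tempfile
from pathlib import Path


def clear_sart_folder(folder):
    """Delete `folder` and everything in it; do nothing if it does not exist."""
    path = Path(folder)
    if path.is_dir():
        shutil.rmtree(path)
        return True
    return False


# Demonstrated on a throwaway directory rather than on real data.
demo_folder = Path(tempfile.mkdtemp()) / "demo-sart"
demo_folder.mkdir()
(demo_folder / "episode-0.h5").write_bytes(b"")
removed = clear_sart_folder(demo_folder)
print(removed, demo_folder.exists())  # True False
```

In this notebook, calling `clear_sart_folder(sart_folder)` before creating the `Loop` would have the same effect as deleting the folder by hand.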

We now load the data from the HDF5 files we just created into a `Batch`. Be aware that you have to set the lookback here as well!

In [None]:
import os

import h5py

# Inspect the first SART file written by the loop (uses `sart_folder` from above).
files = os.listdir(sart_folder)

filename = files[0]
print(f"file {filename}")
with h5py.File(os.path.join(sart_folder, filename), "r") as f:
    print(f.keys())
    print(f["action"].keys())
    print(f["action"]["move"])
    print(f["action"]["move"][0])

In [8]:
batch = Batch.from_hdf5(sart_folder, lookback=lookback, control=controller)

Also note the logs: at the bottom, the logs will always tell you how many episodes were loaded.  If you want to know at any other point how many episodes the batch has loaded, check the `num_episodes` property.

In [9]:
batch.num_episodes

5

Neural networks train best on normalized data, so we fit NFQ's normalizer by passing it the observations in the batch.

In [10]:
controller.fit_normalizer(batch.observations)
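
The tutorial does not state which normalization method NFQ's normalizer applies. As an illustration of the general idea only, the sketch below implements per-channel z-score normalization (zero mean, unit variance), which is one common choice; `fit_zscore` and `transform` are hypothetical names.

```python
from statistics import mean, pstdev


def fit_zscore(observations):
    """Return per-channel (mean, std) pairs for a list of observation rows."""
    channels = list(zip(*observations))
    # Fall back to a std of 1.0 for constant channels to avoid dividing by zero.
    return [(mean(c), pstdev(c) or 1.0) for c in channels]


def transform(observations, params):
    """Shift and scale every channel to zero mean and unit variance."""
    return [[(x - m) / s for x, (m, s) in zip(row, params)] for row in observations]


observations = [[0.0, 10.0], [2.0, 10.0], [4.0, 10.0]]
params = fit_zscore(observations)
normalized = transform(observations, params)
print([row[0] for row in normalized])
```

The parameters are fitted once on the batch and then reused for every later observation, which is why the fit has to happen before training.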


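Now it is time for fitting. As a sanity check, the next cell first prints the first entry of `batch.states_actions`: each stacked state has one row per channel and one column per lookback step. After that, we pass the batch to `controller.fit` and train. This will take a couple of minutes.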

In [11]:
from pprint import pprint
pprint(batch.states_actions[0])

(array([[[ 0.45524758,  0.44397676],
        [-0.06777441,  0.23303981],
        [-0.00843003, -0.00700454],
        [-0.06461398, -0.32833362],
        [-1.0127394 ,  0.9874209 ]],

       [[ 0.44397676,  0.48202646],
        [ 0.23303981,  0.53385454],
        [-0.00700454, -0.05634047],
        [-0.32833362, -0.5920086 ],
        [ 0.9874209 ,  0.9874209 ]],

       [[ 0.48202646,  0.56939673],
        [ 0.53385454,  0.834809  ],
        [-0.05634047, -0.15642917],
        [-0.5920086 , -0.8573697 ],
        [ 0.9874209 ,  0.9874209 ]],

       [[ 0.56939673,  0.7061105 ],
        [ 0.834809  ,  1.136029  ],
        [-0.15642917, -0.30759528],
        [-0.8573697 , -1.1261109 ],
        [ 0.9874209 ,  0.9874209 ]],

       [[ 0.7061105 ,  0.8922112 ],
        [ 1.136029  ,  1.4376153 ],
        [-0.30759528, -0.51048934],
        [-1.1261109 , -1.3998632 ],
        [ 0.9874209 ,  0.9874209 ]],

       [[ 0.8922112 ,  1.1277591 ],
        [ 1.4376153 ,  1.73963   ],
        [-0.51048

In [12]:
controller.fit(
    batch,
    iterations=2,
    epochs=10,
    minibatch_size=32,
    gamma=1.0,
    verbose=1,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
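
The 20 "Epoch" lines above correspond to `iterations=2` rounds of `epochs=10` supervised training each. In the textbook formulation of NFQ, every iteration first recomputes the regression targets from the current Q-function and then fits the network on them. The sketch below shows that target computation in the cost-minimizing form; it is a simplified assumption and may differ from what psipy does internally.

```python
def nfq_targets(transitions, q_function, actions, gamma=1.0):
    """Compute one Q-learning target per transition (cost formulation).

    Each transition is (state, action, cost, next_state, terminal).
    """
    targets = []
    for state, action, cost, next_state, terminal in transitions:
        if terminal:
            target = cost  # no future cost after a terminal state
        else:
            # Costs are minimized, so bootstrap from the cheapest next action.
            target = cost + gamma * min(q_function(next_state, a) for a in actions)
        targets.append(target)
    return targets


# Toy Q-function: a fixed lookup table instead of a neural network.
q_table = {("s1", 0): 0.2, ("s1", 1): 0.5}
transitions = [
    ("s0", 1, 0.1, "s1", False),
    ("s1", 0, 1.0, "s2", True),
]
targets = nfq_targets(transitions, lambda s, a: q_table[(s, a)], actions=(0, 1))
print(targets)
```

The first target is the immediate cost plus the cheapest next-state value (0.1 + 0.2), while the terminal transition keeps only its immediate cost.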


Hurray! Now we can see how our trained controller fares live on the plant. We run the controller by creating a new `Loop`, but this time we do not store the data in our SART folder (otherwise, if we trained again, we would also train on this data, i.e. growing batch). We do this by changing the `logdir` param to something different from our SART folder; here we append "-evaluation" to the SART folder name. Finally, we render the environment so we can see what is happening. Enjoy!

*Note: the environment will not close on its own. We know of this issue and won't fix it (yet) :D*

In [13]:
eval_loop = Loop(plant, controller, "CartPoleEval", f"{sart_folder}-evaluation", render=True)
eval_loop.run(5)


{1: {'total_cost': 0.09065281760477525,
  'cycles_run': 12,
  'wall_time_s': 0.7902},
 2: {'total_cost': 0.06995367967053981,
  'cycles_run': 12,
  'wall_time_s': 0.7696},
 3: {'total_cost': 0.06231771877360675,
  'cycles_run': 10,
  'wall_time_s': 0.7305},
 4: {'total_cost': 0.11924910411321367,
  'cycles_run': 14,
  'wall_time_s': 0.8085},
 5: {'total_cost': 0.06682355553658259,
  'cycles_run': 10,
  'wall_time_s': 0.7256}}
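
To compare runs at a glance, the per-episode metrics can be reduced to a few averages. The helper `summarize` is hypothetical; the numbers are copied (costs rounded) from the evaluation output above.

```python
from statistics import mean


def summarize(metrics):
    """Average the per-episode entries of a metrics dict."""
    episodes = list(metrics.values())
    return {
        "episodes": len(episodes),
        "mean_cycles": mean(e["cycles_run"] for e in episodes),
        "mean_cost": mean(e["total_cost"] for e in episodes),
    }


# cycles_run and total_cost copied (rounded) from the evaluation output above.
evaluation = {
    1: {"total_cost": 0.0907, "cycles_run": 12},
    2: {"total_cost": 0.0700, "cycles_run": 12},
    3: {"total_cost": 0.0623, "cycles_run": 10},
    4: {"total_cost": 0.1192, "cycles_run": 14},
    5: {"total_cost": 0.0668, "cycles_run": 10},
}
summary = summarize(evaluation)
print(summary["episodes"], summary["mean_cycles"])  # 5 11.6
```

Applying the same helper to the metrics of the exploration run gives a baseline to judge the trained controller against.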


In [None]:
num_cycles = 10
iterations = 5

loop = Loop(plant, controller, "GymCartPole", sart_folder, render=True)

for cycle in range(num_cycles):
    loop.run(5)

    batch.append_from_hdf5(sart_folder)
    print(f"Current batch size: {batch.num_episodes}")

    controller.fit(
        batch,
        iterations=50,
        epochs=30,
        minibatch_size=32,
        gamma=1.0,
        verbose=1,
    )



## Growing Batch Tutorial
We will use the `CartPoleSway` env here, since `CartPole-v0` has a hardcoded discrete action space.

As with the Batch tutorial, we will lay out all of our learning components for ease of use and reading.  Here, we define some placeholder variables for our plant, action, state, and lookback (how many observations are in a stack to create a state).

In [None]:
from psipy.rl.plant.gym.cartpole_plants import (
    CartPoleSwayContAction,
    CartPoleSwayContinuousPlant,
    CartPoleSwayState,
)

plant = CartPoleSwayContinuousPlant()  # Note that it is instantiated!
action = CartPoleSwayContAction
state = CartPoleSwayState
lookback = 1
sart_folder = "tutorial-growing-sart"
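
To make the role of `lookback` concrete, here is a minimal NumPy sketch of stacking consecutive observations into states.  The helper `stack_observations` is hypothetical and not part of psipy, and the oldest-first ordering inside each stack is an assumption.  It only illustrates why a state ends up with the shape `(channels, lookback)`.

```python
import numpy as np


def stack_observations(observations: np.ndarray, lookback: int) -> np.ndarray:
    """Stack consecutive observations into states (hypothetical helper).

    observations: array of shape (timesteps, channels).
    Returns an array of shape (timesteps - lookback + 1, channels, lookback).
    """
    timesteps = observations.shape[0]
    if lookback < 1 or lookback > timesteps:
        raise ValueError("lookback must be between 1 and the number of timesteps")
    windows = [
        observations[start : start + lookback].T  # one state: (channels, lookback)
        for start in range(timesteps - lookback + 1)
    ]
    return np.stack(windows)


# Toy episode: 6 timesteps with 4 channels each (values are arbitrary).
episode = np.arange(24, dtype=float).reshape(6, 4)

states = stack_observations(episode, lookback=3)
print(states.shape)  # (4, 4, 3)
```

With `lookback = 1`, as used in this tutorial, every state is a single observation with a trailing axis of length 1.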

We want to use NFQCA to control this plant.  Therefore, we need two neural network models, one for the actor and one for the critic.  Below we use functions to create the models, but this is not necessary.  Note that we need to use the lookback to properly shape our neural networks' inputs.

In [None]:
from psipy.rl.control.nfqca import NFQCA

def make_actor(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_actor")
    net = tfkl.Flatten()(inp)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="tanh")(net)
    return tf.keras.Model(inp, net, name="actor")


def make_critic(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_critic")
    act = tfkl.Input((1,), name="act_in")
    net = tfkl.Concatenate()([tfkl.Flatten()(inp), tfkl.Flatten()(act)])
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="sigmoid")(net)
    return tf.keras.Model([inp, act], net, name="critic")

actor = make_actor(len(state.channels()), lookback)
critic = make_critic(len(state.channels()), lookback)

controller = NFQCA(
    actor=actor, 
    critic=critic, 
    state_channels=state.channels(), 
    action=CartPoleSwayContAction,
    lookback=lookback
)

Now we create the `Loop`, which takes a plant, a controller, a name (the name in the SART logs), and a path to save the SART logs.

In [None]:
loop = Loop(plant, controller, "CartPoleSway", sart_folder)

Now we will use NFQCA to do the initial exploration, as well as all data collection.  We first collect some initial data outside the loop and then collect more data within the growing batch loop.  The extra print statements only show what is going on; they are not required.

Note that depending on the internal cost function of the `CartPoleSwayEnv` at this time, NFQCA might learn nothing.  This tutorial just shows the general form of training NFQCA.

In [None]:
num_cycles = 3
iterations = 2

loop.run(10)
batch = Batch.from_hdf5(sart_folder, lookback=lookback, control=controller)

for cycle in range(num_cycles):
    LOG.info("Cycle: %d", cycle + 1)
    print(f"Current batch size: {batch.num_episodes}")
    # Fit the normalizer on the data. Refitting on every cycle makes the fit
    # parameters home in on the true population parameters
    # (See Batch Tutorial above for more detail on how normalization works)
    controller.fit_normalizer(batch.observations, method="meanstd")

    # NFQCA does not have a generic fit method
    for iteration in range(iterations):
        LOG.info("NFQCA Iteration: %d", iteration + 1)
        controller.fit_critic(batch,
                         iterations=1,
                         epochs=10,
                         minibatch_size=-1,
                         gamma=1.0,
                         verbose=0)
        controller.fit_actor(batch,
                        epochs=10,
                        minibatch_size=-1,
                        verbose=0)

    loop = Loop(plant, controller, "GrowingBatch", sart_folder)
    loop.run(5)
    # Batch.append_from_hdf5() appends any new files found in the folder
    batch.append_from_hdf5(sart_folder)
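
The overall shape of this cycle can be sketched without psipy.  In the toy version below, `collect_episodes` is a hypothetical stand-in for `loop.run(...)` that returns random observations, and `fit_meanstd` is a hypothetical stand-in for a mean/std normalizer fit.  The sketch assumes one stored episode per requested run and only shows how the batch grows and how the normalizer is refitted each cycle.

```python
import numpy as np

rng = np.random.default_rng(0)


def collect_episodes(num_episodes, steps=50):
    """Hypothetical stand-in for `loop.run(...)`: returns random observations."""
    return [rng.normal(loc=2.0, scale=3.0, size=(steps, 4)) for _ in range(num_episodes)]


def fit_meanstd(episodes):
    """Fit per-channel mean and standard deviation on all observations."""
    observations = np.concatenate(episodes, axis=0)
    return observations.mean(axis=0), observations.std(axis=0)


batch = collect_episodes(10)  # initial data, collected before the cycles
sizes = []
for cycle in range(3):
    sizes.append(len(batch))
    mean, std = fit_meanstd(batch)  # refit the normalizer on the grown batch
    # ... the controller would be fitted on the normalized batch here ...
    batch.extend(collect_episodes(5))  # collect more data, growing the batch

print(sizes, len(batch))  # [10, 15, 20] 25
```

Under the same assumption, the tutorial cell above would print batch sizes of 10, 15, and 20, provided the SART folder was empty at the start.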

Now we can see how the model performs live.

In [None]:
loop = Loop(plant, controller, "GrowingBatchEval", f"live-{sart_folder}", render=True)
loop.run(5)

## Advanced Learning Tutorial

You've made it to the next level! Congratulations.  Now it is time to delve deeper into the power of offline reinforcement learning.  The general guidelines for training will not be repeated here.

Let's train `CartPoleSway` again, but this time alter the cost function and create fake transitions to aid the cart to its goal.

### Problem Setting
We want the cart to move to position 0 (middle of the screen) from any starting position.  We do not care about the pole.

For this, we will generate a cost function that only deals out cost based on position, and create a fake transition set that shows 0 cost at the goal position.

*Note the imports inside the functions: this is bad practice but I do it here to show what imports we need.*

In [None]:
def create_fake_episodes(sart_path:str, lookback:int): # ->List[Episode]
    """Create a fake episode at position 0 for every episode already collected
    
    Since the cart probably was never at position 0 exactly, we add a set of Episodes with this transition.
    """
    import glob
    from psipy.rl.io.sart import SARTReader
    from psipy.rl.io.batch import Episode
    
    more_episodes = []

    for path in glob.glob(f"{sart_path}/*.hdf5"):
        with SARTReader(path) as reader:
            o, a, t, c = reader.load_full_episode()

            # Add episode full of goal states
            o = o.copy()
            a = a.copy()
            o[:, 0] = 0
            o[:, 1] = 0
            # A swinging pole affects the position of the cart,
            # so we say no swing here as well
            o[:, 2] = 180
            o[:, 3] = 0
            a[:] = 0  # The cart should not move once in this position
            more_episodes.append(Episode(o, a, t, c, lookback=lookback))
            
    return more_episodes

def costfunc(states:np.ndarray): # -> np.ndarray
    """Recalculate costs on all states provided
    
    This calculates costs on multiple states, so it returns an array
    """
    from psipy.rl.control.nfq import tanh2
    # Position is already defined against 0 (relative)
    position = states[:, 0]
    cost = tanh2(position, C=.2, mu=.1)
    
    return cost
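
To see why the fake goal-state episodes carry zero cost, here is a small NumPy sketch.  The function `tanh2_sketch` is a stand-in that assumes a tanh-squared cost shape (0 at position 0, roughly `0.95 * C` at distance `mu`); it is not psipy's `tanh2`, whose exact definition should be checked in the library source.  The meaning of columns 1 and 3 (velocities) is also an assumption.

```python
import numpy as np


def tanh2_sketch(x, C, mu):
    """Assumed tanh-squared cost: 0 at x = 0, about 0.95 * C at |x| = mu.

    This is a stand-in, not psipy's `tanh2`.
    """
    w = np.arctanh(np.sqrt(0.95)) / mu
    return C * np.tanh(np.abs(x) * w) ** 2


# Toy observations; columns are assumed to be
# (position, velocity, angle, angular velocity).
observations = np.array([
    [0.50, 0.1, 170.0, 2.0],
    [0.10, -0.2, 185.0, -1.0],
    [-0.30, 0.0, 180.0, 0.5],
])

# Goal-state copy, overwritten the same way as in `create_fake_episodes`.
goal = observations.copy()
goal[:, [0, 1, 3]] = 0
goal[:, 2] = 180

real_cost = tanh2_sketch(observations[:, 0], C=0.2, mu=0.1)
goal_cost = tanh2_sketch(goal[:, 0], C=0.2, mu=0.1)
print(real_cost.round(3))  # positive costs, bounded by C
print(goal_cost)           # all zeros at the goal position
```

Because the cost depends only on the position column, the fake episodes give the critic explicit examples of zero-cost states that the random exploration probably never visited.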

Let's set everything up until we need to add the fake transitions.

In [None]:
from psipy.rl.plant.gym.cartpole_plants import (
    CartPoleSwayContAction,
    CartPoleSwayContinuousPlant,
    CartPoleSwayState,
)
from psipy.rl.control.nfqca import NFQCA


plant = CartPoleSwayContinuousPlant()  # Note that it is instantiated!
action = CartPoleSwayContAction
state = CartPoleSwayState
lookback = 5
sart_folder = "tutorial-advanced-sart"


def make_actor(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_actor")
    net = tfkl.Flatten()(inp)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="tanh")(net)
    return tf.keras.Model(inp, net, name="actor")


def make_critic(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_critic")
    act = tfkl.Input((1,), name="act_in")
    net = tfkl.Concatenate()([tfkl.Flatten()(inp), tfkl.Flatten()(act)])
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="sigmoid")(net)
    return tf.keras.Model([inp, act], net, name="critic")

actor = make_actor(len(state.channels()), lookback)
critic = make_critic(len(state.channels()), lookback)

controller = NFQCA(
    actor=actor, 
    critic=critic, 
    state_channels=state.channels(), 
    action=CartPoleSwayContAction,
    lookback=lookback
)

loop = Loop(plant, controller, "CartPolePosition", sart_folder)

The critic receives the cost function, and before we start training we want to add our fake goal position transitions.  Therefore, we pass `costfunc` to `fit_critic` and append the fake episodes before we start the fitting cycles.

We will print the size of the batch so it can be seen how the fake episodes increase the batch episode count.

In [None]:
num_cycles = 3
iterations = 3

loop.run(10)
batch = Batch.from_hdf5(sart_folder, lookback=lookback, control=controller)
print(f"Current batch size: {batch.num_episodes}")
# We now append the fake created episodes
# We could also throw away the created batch and only have the episodes
# created in the function by doing:
# batch = Batch(create_fake_episodes(sart_folder, lookback))
batch = batch.append(create_fake_episodes(sart_folder, lookback))
        
for cycle in range(num_cycles):
    LOG.info("Cycle: %d", cycle + 1)
    print(f"Current batch size: {batch.num_episodes}")
    controller.fit_normalizer(batch.observations, method="meanstd")

    for iteration in range(iterations):
        LOG.info("NFQCA Iteration: %d", iteration + 1)
        controller.fit_critic(batch,
                         iterations=1,
                         # We add the cost function here
                         costfunc=costfunc,
                         epochs=10,
                         minibatch_size=-1,
                         gamma=1.0,
                         verbose=0)
        controller.fit_actor(batch,
                        epochs=10,
                        minibatch_size=-1,
                        verbose=0)

    loop = Loop(plant, controller, "GrowingBatch", sart_folder)
    loop.run(10)
    batch.append_from_hdf5(sart_folder)
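
As a rough intuition for why the critic is the component that receives the cost function, the sketch below computes fitted-Q style targets from recomputed costs.  It assumes the common form "immediate cost plus discounted critic estimate at the next state"; the helper `critic_targets` is hypothetical, and psipy's internal target computation may differ (for example in scaling or clipping).

```python
import numpy as np


def critic_targets(costs, q_next, terminals, gamma=1.0):
    """Assumed target: immediate cost plus discounted expected future cost.

    Terminal transitions have no successor, so only their immediate cost counts.
    """
    costs = np.asarray(costs, dtype=float)
    q_next = np.asarray(q_next, dtype=float)
    terminals = np.asarray(terminals, dtype=bool)
    return costs + gamma * q_next * (~terminals)


# Costs as they might come out of `costfunc`; the last transition is a goal state.
costs = [0.2, 0.19, 0.0]
# Made-up critic estimates for the next states.
q_next = [0.5, 0.3, 0.4]
terminals = [False, False, True]

targets = critic_targets(costs, q_next, terminals, gamma=1.0)
print(targets)
```

Changing the cost function therefore changes what the critic is regressed towards, and the actor follows because it is fitted against the critic.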

Let's see if the controller moves to the goal position.  Oh the tension!

In [None]:
loop = Loop(plant, controller, "GrowingBatchEval", f"live-{sart_folder}", render=True)
loop.run(3)

Congratulations, you've made it through the Advanced Tutorial.  You are now an expert!

## Please delete the SART logs created by this tutorial if you no longer need them.  Or run the cell below to do that automatically.

In [None]:
import shutil
dirs = ["tutorial-advanced-sart", 
        "tutorial-batch-sart", 
        "tutorial-growing-sart", 
        "live-tutorial-growing-sart", 
        "live-tutorial-batch-sart", 
        "live-tutorial-advanced-sart"]
for d in dirs:
    try:
        shutil.rmtree(d)
    except Exception as e:
        print(e)