# Batch Learning Tutorials
This tutorial will explain how to train a controller to perform the batch and growing batch learning tasks.

We will use our simulated CartPole for this batch tutorial.

In [None]:
import logging

import numpy as np

import tensorflow as tf
from tensorflow.keras import layers as tfkl

from psipy.rl.loop import Loop
from psipy.rl.core.controller import DiscreteRandomActionController, ContinuousRandomActionController
from psipy.rl.io.batch import Batch, Episode

LOG = logging.getLogger("psipy")

The following cell tests that CartPole and pygame are installed correctly. You should see a pygame window pop up and some (random) trajectories on the cartpole plant.

In [None]:
from psipy.rl.plants.simulated.cartpole import (
    CartPole,
    CartPoleState,
    CartPoleBangAction)

plant = CartPole(start_angle=0,          # start upright (default: start hanging down)
                 valid_angle=np.pi / 6)  # terminate episode and reset if pole falls below this angle.

rc = DiscreteRandomActionController(state_channels=CartPoleState.channels(),
                                    action=CartPoleBangAction)

plant.notify_episode_starts()
state = plant.check_initial_state(None)

for _ in range(500):
    action = rc.get_action(state)
    state = plant.get_next_state(state, action)
    plant.render()

    if state.terminal:
        plant.notify_episode_stops()
        plant.reset()
        plant.notify_episode_starts()
        state = plant.check_initial_state(None)

## Background
We use two learning paradigms, depending on the application and situation: batch and growing batch.
### Batch
In the batch paradigm, data is collected through some means: a random controller, a human, or otherwise.  A controller is trained on this data until convergence, and then applied to the plant to test final performance.  In simple tasks, this should be sufficient, especially if the goal state or other desirable regions of the state space can be reached by a random policy.
<img src="batch-paradigm.png">
### Growing Batch
In the growing batch paradigm, the initial data can be collected the same way as in the batch approach above. A controller is then trained on the initial batch for some iterations of the algorithm, and this controller is then used to collect *more* data from the plant.  The 'batch of data' grows, and the controller is trained on this bigger batch.  In this way, the controller essentially explores the state space itself, and can 'learn from its mistakes'.
<img src="growing-paradigm.png">
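
The difference between the two paradigms is easiest to see as control flow. The following is a minimal, library-independent sketch; `collect`, `fit`, `toy_collect`, and `toy_fit` are hypothetical stand-ins for the data collection and training steps, not psipy functions.

```python
from typing import Callable, List, Tuple


def growing_batch(
    collect: Callable[[object], List[tuple]],
    fit: Callable[[List[tuple]], object],
    initial_batch: List[tuple],
    cycles: int,
) -> Tuple[object, List[tuple]]:
    """Alternate between fitting on the batch and collecting with the new policy."""
    batch = list(initial_batch)
    policy = fit(batch)  # plain batch learning stops after this single fit
    for _ in range(cycles):
        batch.extend(collect(policy))  # the batch grows with data from the current policy
        policy = fit(batch)            # refit on old and new data together
    return policy, batch


def toy_fit(batch):
    # Toy "training": the policy is just the number of transitions it has seen.
    return len(batch)


def toy_collect(policy):
    # Toy "interaction": every run yields two (state, action, cost, next_state) tuples.
    return [("s", "a", "c", "s_next")] * 2


policy, batch = growing_batch(toy_collect, toy_fit, [("s", "a", "c", "s_next")], cycles=3)
```

With `cycles=0` the function reduces to the batch paradigm: one fit on the initial data and no further collection.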

## Batch Tutorial
It is best practice to lay out all your learning components for ease of use and reading.  Here, we define some placeholder variables for our plant, action, state, and lookback (how many observations are in a stack to create a state).

In [None]:
state_type = CartPoleState
action_type = CartPoleBangAction

lookback = 1   # no history because we have no latencies in the simulation
sart_folder = "psidata-tutorial-batch-sart"
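
To illustrate what the lookback does, here is a minimal NumPy sketch. It assumes that a state is built by stacking the last `lookback` observations along a trailing axis, which matches the `(n_inputs, lookback)` input shape of the models in this tutorial. `stack_observations` is a hypothetical helper, not part of psipy.

```python
import numpy as np


def stack_observations(observations: np.ndarray, lookback: int) -> np.ndarray:
    """Turn (T, n_channels) observations into (T - lookback + 1, n_channels, lookback) states."""
    n_steps = observations.shape[0] - lookback + 1
    # Each window holds `lookback` consecutive observations; transpose to (channels, lookback).
    windows = [observations[i : i + lookback].T for i in range(n_steps)]
    return np.stack(windows)


obs = np.arange(12, dtype=float).reshape(6, 2)  # 6 time steps, 2 channels
states = stack_observations(obs, lookback=3)    # shape (4, 2, 3)
single = stack_observations(obs, lookback=1)    # shape (6, 2, 1): no history
```

With `lookback=1`, as used here, every state is exactly one observation with a trailing axis of length one.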

We want to use NFQ to control this plant.  Therefore, we need a neural network as its internal model, which we create below. We also select the subset of channels that we want to use for the state representation. Please note that these can be, and in this case are, different from the channels used internally by the CartPole: we use (sin, cos) to represent the pole angle, not the pole angle directly.

In [None]:
from psipy.rl.controllers.nfq import NFQ

state_channels = [
    "cart_position",
    "cart_velocity",
    "pole_sine",
    "pole_cosine",
    "pole_velocity"]

n_inputs = len(state_channels)
n_outputs = len(action_type.legal_values[0])  # one output per legal action value

def make_model(n_inputs, n_outputs, lookback):
    inp = tfkl.Input((n_inputs, lookback), name="state")
    net = tfkl.Flatten()(inp)
    net = tfkl.Dense(256, activation="relu")(net)
    net = tfkl.Dense(256, activation="relu")(net)
    net = tfkl.Dense(100, activation="tanh")(net)
    net = tfkl.Dense(n_outputs, activation="sigmoid")(net)
    model = tf.keras.Model(inp, net)
    model.summary()
    return model

model = make_model(n_inputs=n_inputs, n_outputs=n_outputs, 
                   lookback=lookback)

# Create the controller with our model.  
controller = NFQ(model=model,
                 state_channels=state_channels, 
                 action=action_type)
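
One common motivation for the (sin, cos) representation is that a raw angle jumps at the wrap-around point, while the encoded pair changes smoothly. The sketch below illustrates this with plain NumPy; `encode_angle` is a hypothetical helper, and the assumption that the raw angle wraps at ±π is ours, not taken from the CartPole implementation.

```python
import numpy as np


def encode_angle(theta: np.ndarray) -> np.ndarray:
    """Encode angles (in radians) as (sin, cos) pairs."""
    return np.stack([np.sin(theta), np.cos(theta)], axis=-1)


# Two nearly identical pole positions on either side of the wrap-around point.
theta = np.array([np.pi - 0.01, -np.pi + 0.01])
raw_gap = abs(theta[0] - theta[1])                       # large jump in the raw angle
encoded = encode_angle(theta)
encoded_gap = np.linalg.norm(encoded[0] - encoded[1])    # small distance in (sin, cos)
```

The raw values differ by almost 2π although the pole is practically in the same place, whereas the encoded points stay close together.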



We also want a controller to explore.  We could use NFQ to explore, since with randomized weights it essentially acts as a random controller, but we will explicitly use a random controller here to demonstrate how to use different controllers at the same time. Since `CartPoleBangAction` is discrete, we use a `DiscreteRandomActionController`.

In [None]:
explorer = DiscreteRandomActionController(state_channels=state_type.channels(), action=action_type)

Now we create the `Loop`, which takes a plant, a controller, a name (the name in the SART logs), and a path to save the SART logs.

In [None]:
loop = Loop(plant, explorer, "CartPole", sart_folder, render=False)

Let's now collect some data with our explorer, the `DiscreteRandomActionController`. We want to collect 100 episodes.
**Attention:** Please ignore possible "ValueErrors" thrown by hdf5. 

In [None]:
loop.run(100)

Be aware that if you run this notebook or this cell multiple times, old data collected from previous runs will also be loaded unless you have deleted the SART folder. 

We now load the data into a `Batch` from the hdf5 files we just created.  Be aware that you have to set the lookback as well as the state channels to use here! We also have to load the action's index (`move_index`), not its actual value (`move`), because this variant of NFQ uses a DQN-like network topology: it has one output neuron for each possible action, and each neuron is trained on that particular action's Q-value. Thus, we need the index of the selected action, which corresponds to the index of the output neuron.

In [None]:
batch = Batch.from_hdf5(sart_folder, lookback=lookback, state_channels=state_channels, control=controller, action_channels=["move_index",])
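
The role of the action index can be sketched with plain NumPy. All numbers below are made up for illustration, and `legal_values` is a hypothetical pair of action values, not necessarily those of `CartPoleBangAction`.

```python
import numpy as np

legal_values = np.array([-10.0, 10.0])     # hypothetical action values
q_values = np.array([[0.2, 0.7],           # network output: one column per action
                     [0.9, 0.1],
                     [0.4, 0.5]])
move_index = np.array([1, 0, 0])           # index of the action taken in each transition

# Q-value of the action actually taken: select the matching output neuron per row.
q_taken = np.take_along_axis(q_values, move_index[:, None], axis=1).squeeze(axis=1)

# The physical action value can always be recovered from the index.
move = legal_values[move_index]
```

Going the other way, from a stored action value back to the output neuron, would require a lookup, which is why the index is the more convenient quantity to load.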

Also note the logs: at the bottom, the logs will always tell you how many episodes were loaded.  If you want to know at any other point how many episodes the batch has loaded, check the `num_episodes` property.

In [None]:
batch.num_episodes

Neural networks like normalized data, so we fit NFQ's normalizer on the observations in the batch by passing them to `fit_normalizer`.

In [None]:
controller.fit_normalizer(batch.observations)
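
As a rough illustration of what fitting a normalizer means, here is a minimal mean/standard-deviation sketch in NumPy. `MeanStdNormalizer` is a hypothetical stand-in, not the psipy class, and the method NFQ uses by default may differ.

```python
import numpy as np


class MeanStdNormalizer:
    """Hypothetical stand-in for a normalizer; not the psipy implementation."""

    def fit(self, observations: np.ndarray) -> "MeanStdNormalizer":
        # Statistics are computed per channel over all observations in the batch.
        self.mean = observations.mean(axis=0)
        # Guard against constant channels to avoid division by zero.
        self.std = np.maximum(observations.std(axis=0), 1e-8)
        return self

    def transform(self, observations: np.ndarray) -> np.ndarray:
        return (observations - self.mean) / self.std


rng = np.random.default_rng(0)
# Two channels on very different scales, e.g. a position and an angular velocity.
observations = rng.normal(loc=[0.0, 50.0], scale=[0.1, 20.0], size=(1000, 2))
normalized = MeanStdNormalizer().fit(observations).transform(observations)
```

After the transform, every channel has zero mean and unit standard deviation, so no single channel dominates the network input.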


Now it is time for fitting.  We pass in the batch and train.  This will take a couple of minutes.

In [None]:
controller.fit(
    batch,
    iterations=20,
    epochs=2,
    minibatch_size=50,
    gamma=0.98,
    verbose=1,
)
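
To give an idea of what happens inside a fitted-Q iteration, the sketch below computes the regression targets for one iteration in NumPy. It assumes a cost-minimizing convention as in the original NFQ formulation (target = cost + gamma * minimum Q-value of the next state, with no future cost after a terminal transition); `nfq_targets` is a hypothetical helper and the psipy implementation may differ in details.

```python
import numpy as np


def nfq_targets(costs, q_next, terminal, gamma):
    """Targets for one fitted-Q iteration under a cost-minimizing convention."""
    best_next = q_next.min(axis=1)                    # cheapest action in the next state
    best_next = np.where(terminal, 0.0, best_next)    # no future cost after termination
    return costs + gamma * best_next


costs = np.array([0.01, 0.01, 1.0])                   # immediate cost per transition
q_next = np.array([[0.3, 0.2], [0.5, 0.6], [0.9, 0.8]])  # Q(s', a) for both actions
terminal = np.array([False, False, True])
targets = nfq_targets(costs, q_next, terminal, gamma=0.98)
```

Presumably, each of the `iterations` recomputes such targets from the current network, and the network is then trained on them for `epochs` epochs.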

Hurray! 

Sort of.

Now we can see how our trained controller fares live in the plant.  Let's run the controller again by creating a new `Loop`, but this time without storing the data in our SART folder (otherwise, if we train again, we would also train on this data, i.e. growing batch).  We do this by changing the `logdir` param to something different from our SART folder; here we append "-evaluation" to the SART folder name.  Finally, we render the environment so we can see what is happening.  Enjoy!

**It's likely your controller will still fail and not balance the pole indefinitely.** You could try to collect more data and run NFQ for more iterations. Or you could read ahead to see how to make this process more (data) efficient.

*Note: the environment will not close on its own. We know of this issue and won't fix it (yet) :D*

In [None]:
eval_loop = Loop(plant, controller, "CartPoleEval", f"{sart_folder}-evaluation", render=True)
eval_loop.run(10)

### Growing Batch Example

The following will grow a batch of data, interleaved with improving the model. Note that we now use the controller instead of the explorer in the loop. We use the same sart_folder, adding new data to the already existing data, which was collected randomly. More data is never bad (as we use an off-policy method), so there is no reason to start from scratch.

This training will take a while due to the network training in between the collection of episodes of interaction with the cartpole. It should learn to balance the pole within 20 cycles of the outer loop or fewer. If it's not perfect, just execute the cell again.

In [None]:
loop = Loop(plant, controller, "CartPole", sart_folder, render=True)
controller.epsilon = 0.1   # take a random action with probability 0.1 (on average every 10th step) for continued exploration of alternatives.

for cycle in range(20):
    print(f"Cycle: {cycle}")

    # execute two explorative runs
    loop.run(2, max_episode_steps=400)

    # load the grown batch
    batch = Batch.from_hdf5(sart_folder,
                            lookback=lookback,
                            state_channels=state_channels,
                            control=controller,
                            action_channels=["move_index",])
    
    # fit the controller for another few iterations
    controller.fit(
        batch,
        iterations=4,
        epochs=8,
        minibatch_size=500,
        gamma=0.98
    )
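
The effect of `epsilon` can be sketched as epsilon-greedy action selection. The helper below is hypothetical and not psipy's implementation; it assumes the Q-values are expected costs, so the greedy action is the one with the smallest value.

```python
import numpy as np


def epsilon_greedy(q_values: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
    """Return a random action index with probability epsilon, else the cheapest one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmin(q_values))  # assumption: Q-values are expected costs


rng = np.random.default_rng(42)
q_values = np.array([0.8, 0.2])  # action 1 is the greedy (cheapest) choice
choices = [epsilon_greedy(q_values, epsilon=0.1, rng=rng) for _ in range(10_000)]
greedy_fraction = np.mean(np.array(choices) == 1)
```

With two actions and `epsilon=0.1`, the greedy action is still chosen about 95% of the time, because half of the random draws happen to pick it as well.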

In [None]:
num_cycles = 10

loop = Loop(plant, controller, "GymCartPole", sart_folder, render=True)

for cycle in range(num_cycles):
    loop.run(5)

    batch.append_from_hdf5(sart_folder, action_channels=["move_index",])
    print(f"Current batch size: {batch.num_episodes}")

    controller.fit(
        batch,
        iterations=20,
        epochs=25,
        minibatch_size=64,
        gamma=0.99,
        verbose=1,
    )

**Attention:** everything below has not yet been adapted to psipy-public and will not work right now. WIP.

## Growing Batch Tutorial with Continuous Actions

We will use the same plant here, but let it start with the pole hanging down, trying to learn to swing it up and balance it. We will also use NFQCA, an actor-critic variant of NFQ with continuous actions.

As with the Batch tutorial, we will lay out all of our learning components for ease of use and reading.  Here, we define some placeholder variables for our plant, action, state, and lookback (how many observations are in a stack to create a state). Some is copied from above, just for clarity.

In [None]:
from psipy.rl.plants.simulated.cartpole import (
    CartPoleContinuousAction,
)
action_type = CartPoleContinuousAction
state_type = CartPoleState

lookback = 1
sart_folder = "psidata-growing-sart"

plant = CartPole(
    action_type=CartPoleContinuousAction
)  # Note that it is instantiated!

state_channels = [
    "cart_position",
    "cart_velocity",
    "pole_sine",
    "pole_cosine",
    "pole_velocity"]

n_inputs = len(state_channels)

We want to use NFQCA to control this plant.  Therefore, we need two neural network internal models, one for the actor and one for the critic.  Below we use functions to create the models, but this is not necessary.  Note that we need to use the lookback to properly shape our neural networks' inputs.

In [None]:
from psipy.rl.controllers.nfqca import NFQCA
from psipy.rl.controllers.noise import RandomNormalNoise


def make_actor(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_actor")
    net = tfkl.Flatten()(inp)
    net = tfkl.Dense(256, activation="relu")(net)
    net = tfkl.Dense(256, activation="relu")(net)
    net = tfkl.Dense(100, activation="tanh")(net)
    net = tfkl.Dense(1, activation="tanh")(net)
    return tf.keras.Model(inp, net, name="actor")


def make_critic(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_critic")
    act = tfkl.Input((1,), name="act_in")
    net = tfkl.Concatenate()([tfkl.Flatten()(inp), tfkl.Flatten()(act)])
    net = tfkl.Dense(256, activation="relu")(net)
    net = tfkl.Dense(256, activation="relu")(net)
    net = tfkl.Dense(100, activation="tanh")(net)
    net = tfkl.Dense(1, activation="sigmoid")(net)
    return tf.keras.Model([inp, act], net, name="critic")

actor = make_actor(n_inputs, lookback)
critic = make_critic(n_inputs, lookback)

controller = NFQCA(
    actor=actor, 
    critic=critic, 
    state_channels=state_channels, 
    action=action_type,
    lookback=lookback,
)

Now we create the `Loop`, which takes a plant, a controller, a name (the name in the SART logs), and a path to save the SART logs.

In [None]:
loop = Loop(plant, controller, "CartpoleSwingup", sart_folder, render=True)

Now we will use NFQCA to do the initial exploration, as well as all data collection.  We first collect some initial data outside the loop and then collect more data within the growing-batch-loop.  We print some extra things so you can see what is going on, but that is unnecessary.

Note that depending on the internal cost function of the `CartPoleSwayEnv` at this time, NFQCA might learn nothing.  This tutorial just shows the general form of training NFQCA.

In [None]:
episodes = 400
steps = 400
cycles = 2

controller.exploration = RandomNormalNoise(size=1, std=1.0)  # this achieves exploration by adding noise to all actions


for episode in range(episodes):

    loop.run(1, max_episode_steps=steps)
    
    batch = Batch.from_hdf5(sart_folder, lookback=lookback, control=controller, state_channels=state_channels)

    for cycle in range(cycles):

        print(f"Iteration {episode}/{cycle} with batch size: {batch.num_episodes}")

        if episode < 20 or (episode < episodes / 2 and episode % 20 == 0):
            controller.fit_normalizer(batch.observations, method="meanstd")

        # NFQCA does not have a generic fit method
        controller.fit_critic(
            batch,
            iterations=2,
            epochs=8,
            minibatch_size=8192,
            gamma=0.98,
            verbose=0)
        
        controller.fit_actor(
            batch,
            epochs=1,
            minibatch_size=2048,
            verbose=0)
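
Noise-based exploration for continuous actions can be sketched as follows. This is an illustration only: `noisy_action` is a hypothetical helper, and clipping to the range [-1, 1] is our assumption (matching the `tanh` output of the actor), not a documented behavior of `RandomNormalNoise`.

```python
import numpy as np


def noisy_action(action, std, low, high, rng):
    """Add zero-mean Gaussian noise to an action and clip it to the legal range."""
    noise = rng.normal(0.0, std, size=np.shape(action))
    return np.clip(action + noise, low, high)


rng = np.random.default_rng(0)
actor_output = np.zeros(1000)  # pretend the actor always proposes 0
explored = noisy_action(actor_output, std=1.0, low=-1.0, high=1.0, rng=rng)
```

Unlike epsilon-greedy exploration, every action is perturbed, and the standard deviation controls how far the executed actions stray from the actor's proposal.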


Now we can see how the model performs live.

In [None]:
loop = Loop(plant, controller, "GrowingBatchEval", f"live-{sart_folder}", render=True)
loop.run(5)

## Advanced Learning Tutorial

You've made it to the next level! Congratulations.  Now it is time to delve deeper into the power of offline reinforcement learning.  The general training guidelines will not be outlined here.

Let's train `CartPoleSway` again, but this time alter the cost function and create fake transitions to aid the cart to its goal.

### Problem Setting
We want the cart to move to position 0 (the middle of the screen) from any starting position.  We do not care about the pole.

For this, we will generate a cost function that only deals out cost based on position, and create a fake transition set that shows 0 cost at the goal position.

*Note the imports inside the functions: this is bad practice but I do it here to show what imports we need.*

In [None]:
def create_fake_episodes(sart_path: str, lookback: int):  # -> List[Episode]
    """Create a fake episode at position 0 for every episode already collected
    
    Since the cart probably was never at position 0 exactly, we add a set of Episodes with this transition.
    """
    import glob
    from psipy.rl.io.sart import SARTReader
    from psipy.rl.io.batch import Episode
    
    more_episodes = []

    for path in glob.glob(f"{sart_path}/*.hdf5"):
        with SARTReader(path) as reader:
            o, a, t, c = reader.load_full_episode()

            # Add episode full of goal states
            o = o.copy()
            a = a.copy()
            o[:, 0] = 0
            o[:, 1] = 0
            # A swinging pole affects the position of the cart,
            # so we say no swing here as well
            o[:, 2] = 180
            o[:, 3] = 0
            a[:] = 0  # The cart should not move once in this position
            more_episodes.append(Episode(o, a, t, c, lookback=lookback))
            
    return more_episodes

def costfunc(states: np.ndarray):  # -> np.ndarray
    """Recalculate costs on all states provided
    
    This calculates costs on multiple states, so it returns an array
    """
    from psipy.rl.control.nfq import tanh2
    # Position is already defined against 0 (relative)
    position = states[:, 0]
    cost = tanh2(position, C=.2, mu=.1)
    
    return cost
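
To get a feeling for the shape of such a cost function, the sketch below implements a squared-tanh cost in NumPy. It assumes the form commonly used in the NFQ literature, where the cost is 0 at the target, reaches 95% of `C` at distance `mu`, and saturates at `C`; `tanh2_sketch` is a hypothetical helper and the actual `tanh2` in psipy may be defined differently.

```python
import numpy as np


def tanh2_sketch(x, C, mu):
    """Smooth cost: 0 at the target, 0.95 * C at distance mu, saturating at C."""
    w = np.arctanh(np.sqrt(0.95)) / mu   # steepness chosen so that cost(mu) = 0.95 * C
    return C * np.tanh(np.abs(x) * w) ** 2


positions = np.array([0.0, 0.1, 2.0])    # at the goal, at distance mu, far away
costs = tanh2_sketch(positions, C=0.2, mu=0.1)
```

Under this assumption, `C` sets the maximum cost per step and `mu` sets how close to the goal the cart has to be before the cost drops noticeably.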

Let's set everything up until we need to add the fake transitions.

In [None]:
from psipy.rl.plant.gym.cartpole_plants import (
    CartPoleSwayContAction,
    CartPoleSwayContinuousPlant,
    CartPoleSwayState,
)
from psipy.rl.control.nfqca import NFQCA


plant = CartPoleSwayContinuousPlant()  # Note that it is instantiated!
action = CartPoleSwayContAction
state = CartPoleSwayState
lookback = 5
sart_folder = "tutorial-advanced-sart"


def make_actor(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_actor")
    net = tfkl.Flatten()(inp)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="tanh")(net)
    return tf.keras.Model(inp, net, name="actor")


def make_critic(inputs, lookback):
    inp = tfkl.Input((inputs, lookback), name="state_critic")
    act = tfkl.Input((1,), name="act_in")
    net = tfkl.Concatenate()([tfkl.Flatten()(inp), tfkl.Flatten()(act)])
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(40, activation="tanh")(net)
    net = tfkl.Dense(1, activation="sigmoid")(net)
    return tf.keras.Model([inp, act], net, name="critic")

actor = make_actor(len(state.channels()), lookback)
critic = make_critic(len(state.channels()), lookback)

controller = NFQCA(
    actor=actor, 
    critic=critic, 
    state_channels=state.channels(), 
    action=CartPoleSwayContAction,
    lookback=lookback
)

loop = Loop(plant, controller, "CartPolePosition", sart_folder)

The critic receives the cost function, and before we start training we want to add our fake goal position transitions.  Therefore, we pass `costfunc` to `fit_critic` and append the fake episodes before we start the fitting cycles.

We will print the size of the batch so it can be seen how the fake episodes increase the batch episode count.

In [None]:
num_cycles = 3
iterations = 3

loop.run(10)
batch = Batch.from_hdf5(sart_folder, lookback=lookback, control=controller)
print(f"Current batch size: {batch.num_episodes}")
# We now append the fake created episodes
# We could also throw away the created batch and only have the episodes
# created in the function by doing:
# batch = Batch(create_fake_episodes(sart_folder, lookback))
batch = batch.append(create_fake_episodes(sart_folder, lookback))
        
for cycle in range(num_cycles):
    LOG.info("Cycle: %d", cycle + 1)
    print(f"Current batch size: {batch.num_episodes}")
    controller.fit_normalizer(batch.observations, method="meanstd")

    for iteration in range(iterations):
        LOG.info("NFQCA Iteration: %d", iteration + 1)
        controller.fit_critic(batch,
                         iterations=1,
                         # We add the cost function here
                         costfunc=costfunc,
                         epochs=10,
                         minibatch_size=-1,
                         gamma=1.0,
                         verbose=0)
        controller.fit_actor(batch,
                        epochs=10,
                        minibatch_size=-1,
                        verbose=0)

    loop = Loop(plant, controller, "GrowingBatch", sart_folder)
    loop.run(10)
    batch.append_from_hdf5(sart_folder)

Let's see if the controller moves to the goal position.  Oh the tension!

In [None]:
loop = Loop(plant, controller, "GrowingBatchEval", f"live-{sart_folder}", render=True)
loop.run(3)

Congratulations, you've made it through the Advanced Tutorial.  You are now an expert!

## Please delete the SART logs created by this tutorial if you no longer need them.  Or run the cell below to do that automatically.

In [None]:
import shutil
# Folder names match the sart_folder values and evaluation log dirs used above.
dirs = ["psidata-tutorial-batch-sart",
        "psidata-tutorial-batch-sart-evaluation",
        "psidata-growing-sart",
        "live-psidata-growing-sart",
        "tutorial-advanced-sart",
        "live-tutorial-advanced-sart"]
for d in dirs:
    try:
        shutil.rmtree(d)
    except Exception as e:
        print(e)