# Outlook

This notebook shows how to use a gymnasium environment as a BBRL agent in practice, using autoreset=True.
It is part of the [BBRL documentation](https://github.com/osigaud/bbrl/docs/index.html).

If this is your first contact with BBRL, you may start by having a look at [this more basic notebook](01-basic_concepts.student.ipynb) and [the one using autoreset=False](02-multi_env_noautoreset.student.ipynb).

## Installation and Imports

The BBRL library is [here](https://github.com/osigaud/bbrl).

Below, we import standard Python packages, PyTorch packages and gymnasium environments.

In [None]:
# Installs the necessary Python and system libraries
try:
    from easypip import easyimport, easyinstall, is_notebook
except ModuleNotFoundError as e:
    get_ipython().run_line_magic("pip", "install easypip")
    from easypip import easyimport, easyinstall, is_notebook

easyinstall("bbrl>=0.2.2")
easyinstall("swig")
easyinstall("bbrl_gymnasium>=0.2.0")
easyinstall("bbrl_gymnasium[classic_control]")

In [None]:
import os
import sys
from pathlib import Path
import math

from moviepy.editor import ipython_display as video_display
import time
from tqdm.auto import tqdm
from typing import Tuple, Optional
from functools import partial

from omegaconf import OmegaConf
import torch
import bbrl_gymnasium

import copy
from abc import abstractmethod, ABC
import torch.nn as nn
import torch.nn.functional as F
from time import strftime
OmegaConf.register_new_resolver(
    "current_time", lambda: strftime("%Y%m%d-%H%M%S"), replace=True
)

In [None]:
# Imports all the necessary classes and functions from BBRL
from bbrl.agents.agent import Agent
from bbrl import get_arguments, get_class, instantiate_class
# The workspace is the main class in BBRL: this is where all data is collected and stored
from bbrl.workspace import Workspace

# Agents(agent1, agent2, agent3, ...) executes the agents one after the other
# TemporalAgent(agent) executes an agent over multiple timesteps in the workspace,
# or until a given condition is reached

from bbrl.agents import Agents, TemporalAgent
from bbrl.agents.gymnasium import ParallelGymAgent, make_env

# Replay buffers are useful to store past transitions when training
from bbrl.utils.replay_buffer import ReplayBuffer

## Definition of agents

As before, we first create an Agent representing [the CartPole-v1 gym environment](https://gymnasium.farama.org/environments/classic_control/cart_pole/).
This is done using the [ParallelGymAgent](https://github.com/osigaud/bbrl/blob/40fe0468feb8998e62c3cd6bb3a575fef88e256f/src/bbrl/agents/gymnasium.py#L261) class.

## Single environment case

We start with a random agent and a single instance of the CartPole environment.

In [None]:
# We deal with 1 environment at a time (random seed 2139)

env_agent = ParallelGymAgent(partial(make_env, env_name='CartPole-v1'), 1).seed(2139)
obs_size, action_dim = env_agent.get_obs_and_actions_sizes()
print(f"Environment: observation space in R^{obs_size} and action space R^{action_dim}")

class RandomAgent(Agent):
    def __init__(self, action_dim):
        super().__init__()
        self.action_dim = action_dim

    def forward(self, t: int, choose_action=True, **kwargs):
        """An Agent can use self.workspace"""
        obs = self.get(("env/env_obs", t))
        action = torch.randint(0, self.action_dim, (len(obs), ))
        self.set(("action", t), action)

# Each agent will be run (in the order given when constructing Agents)
agents = Agents(env_agent, RandomAgent(action_dim))
t_agents = TemporalAgent(agents)
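
To make the execution order concrete, here is a minimal plain-Python sketch of the idea behind Agents and TemporalAgent. It is not BBRL's implementation: `toy_env`, `toy_policy` and `run_temporal` are hypothetical names, the workspace is reduced to a dictionary indexed by (variable name, time step), and the toy dynamics (a random walk that ends when it reaches +3 or -3) are invented for the illustration.

```python
import random


def toy_env(ws, t):
    """Writes an observation and a done flag at time t (using the action at t-1)."""
    if t == 0:
        ws[("env/env_obs", t)] = 0
    else:
        # Toy dynamics: the observation moves by +1 or -1 depending on the action.
        step = 1 if ws[("action", t - 1)] == 1 else -1
        ws[("env/env_obs", t)] = ws[("env/env_obs", t - 1)] + step
    ws[("env/done", t)] = abs(ws[("env/env_obs", t)]) >= 3


def toy_policy(ws, t, rng):
    """Reads the observation at time t and writes a random action at time t."""
    _ = ws[("env/env_obs", t)]
    ws[("action", t)] = rng.randint(0, 1)


def run_temporal(agent_list, ws, stop_variable, max_steps=1000):
    """Runs every agent in order at t = 0, 1, ... until stop_variable is True."""
    for t in range(max_steps):
        for agent in agent_list:
            agent(ws, t)
        if ws[(stop_variable, t)]:
            return t
    return max_steps - 1


toy_rng = random.Random(2139)
toy_ws = {}  # hypothetical stand-in for a workspace
toy_agents = [toy_env, lambda ws, t: toy_policy(ws, t, toy_rng)]
last_t = run_temporal(toy_agents, toy_ws, stop_variable="env/done")
print(f"Episode stopped at t={last_t}, final observation {toy_ws[('env/env_obs', last_t)]}")
```

The point of the sketch is the nesting of the loops: at each time step, the environment writes first and the policy reads what it just wrote, which is why the order given when constructing Agents matters.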

Let us have a closer look at the content of the workspace.

In [None]:
# Creates a new workspace
workspace = Workspace() 
t_agents(workspace, stop_variable="env/done")

# We get the transitions: each tensor is transformed so that
# - it holds the values at time steps t and t+1 (so the first dimension of every tensor has size 2)
# - there is no distinction between the different environments (here, for simplicity, there is just one environment)
transitions = workspace.get_transitions()

# You can see that each pair of observations in the transitions can be found in the workspace
display("Observations (first 3)", workspace["env/env_obs"][:3, 0])

display("Transitions of observations (first 3)")
for t in range(3):
    display(f'(s_{t}, s_{t+1})')
    display(transitions["env/env_obs"][:, t])
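
The transformation performed by get_transitions can be pictured with a small plain-Python sketch. It is an illustration under assumptions, not BBRL's code: `toy_transitions` is a hypothetical helper, variables are stored as lists indexed by [time][environment], and a pair (t, t+1) is assumed to be dropped when the done flag is set at t, since with autoreset it would link the end of one episode to the start of the next.

```python
def toy_transitions(values, done):
    """Pairs each value at time t with the value at t+1, skipping t where done[t] is True.

    values and done are lists indexed by [time][environment].
    Returns [firsts, seconds], two lists of equal length, so that
    (firsts[i], seconds[i]) is the i-th transition.
    """
    firsts, seconds = [], []
    for t in range(len(values) - 1):
        for env in range(len(values[t])):
            if not done[t][env]:
                firsts.append(values[t][env])
                seconds.append(values[t + 1][env])
    return [firsts, seconds]


# One environment, five time steps: the first episode ends at t=2,
# and the environment is reset at t=3 (observation "r0").
toy_obs = [["a0"], ["a1"], ["a2"], ["r0"], ["r1"]]
toy_done = [[False], [False], [True], [False], [False]]

pairs = toy_transitions(toy_obs, toy_done)
for i in range(len(pairs[0])):
    print(f"transition {i}: ({pairs[0][i]}, {pairs[1][i]})")
```

In this toy run, the pair (a2, r0) is left out because it crosses an episode boundary, while the outer list of size 2 plays the role of the first tensor dimension.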

## Multiple environment case

We now use 3 environments.
Given the organization of transitions, to find the transitions of a particular environment
we have to look at every third entry of the transitions, since at each time step the transitions are stored one environment after the other.
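
As an illustration of this layout, the plain-Python sketch below flattens a toy rollout of 3 environments and recovers the transitions of one environment with a stride of 3. All names are hypothetical, and the fixed stride is an assumption that only holds as long as no transition is filtered out (for instance at episode boundaries).

```python
n_envs, n_steps = 3, 4

# Toy observations labelled "e<env>t<time>", indexed by [time][environment].
toy_obs = [[f"e{env}t{t}" for env in range(n_envs)] for t in range(n_steps)]

# Flatten the (t, t+1) pairs: time is the outer loop, environment the inner one,
# so consecutive entries belong to consecutive environments.
flat = [(toy_obs[t][env], toy_obs[t + 1][env])
        for t in range(n_steps - 1)
        for env in range(n_envs)]

# The transitions of environment 1 are found every n_envs entries, starting at index 1.
env_id = 1
env_transitions = flat[env_id::n_envs]
for pair in env_transitions:
    print(pair)
```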

## The replay buffer

Unlike in the previous case, we use a replay buffer that stores
a set of transitions $(s_t, a_t, r_t, s_{t+1})$.
The replay buffer keeps slices [:, i, ...] of the transition
workspace (here, at most 100 transitions).
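
The idea of a bounded buffer of transitions can be sketched in plain Python as follows. This is not the BBRL ReplayBuffer: `ToyReplayBuffer` and its methods are hypothetical, transitions are plain tuples, and the eviction policy (dropping the oldest transitions first) is an assumption made for the illustration.

```python
import random
from collections import deque


class ToyReplayBuffer:
    """Keeps at most max_size transitions, discarding the oldest ones first."""

    def __init__(self, max_size, seed=None):
        self.storage = deque(maxlen=max_size)
        self.rng = random.Random(seed)

    def put(self, transitions):
        """Adds an iterable of (s_t, a_t, r_t, s_next) tuples."""
        self.storage.extend(transitions)

    def sample(self, batch_size):
        """Returns batch_size distinct transitions drawn uniformly at random."""
        return self.rng.sample(list(self.storage), batch_size)

    def __len__(self):
        return len(self.storage)


rb_toy = ToyReplayBuffer(max_size=100, seed=0)
# 150 toy transitions: only the 100 most recent ones are kept.
rb_toy.put((s, s % 2, 1.0, s + 1) for s in range(150))
batch = rb_toy.sample(3)
print(len(rb_toy), batch)
```

Sampling uniformly from stored transitions, rather than replaying them in order, is what decorrelates the samples used in a training batch.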

In [None]:
rb = ReplayBuffer(max_size=100)

# We add the transitions to the buffer...
rb.put(transitions)

# ...and sample from them: here we get 3 pairs (s_t, s_{t+1})
rb.get_shuffled(3)["env/env_obs"]

A transition workspace is still a workspace. This is quite
handy: since each transition can be seen as a mini-episode of two time steps,
we can use our agents on it:

In [None]:
# Display the current actions, for reference

display(transitions["action"])

t_random_agent = TemporalAgent(RandomAgent(action_dim))
t_random_agent(transitions, t=0, n_steps=2)

# Here, the action tensor will have been overwritten by the new actions
display(transitions["action"])