<a href="https://colab.research.google.com/github/siriusted/gym-dssat-notebooks/blob/master/SB3_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gym-DSSAT x Stable-Baselines3 Tutorial

Welcome to a brief introduction to using gym-dssat with stable-baselines3.

For background on reinforcement learning with stable-baselines3 (SB3), see [this tutorial](https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3).

This notebook assumes familiarity with reinforcement learning and stable-baselines3, so the focus is on interacting with gym-dssat through SB3.

Next, we install the required packages.

**Note**: The installation takes a while.

# Installation
- gym_dssat

In [None]:
!wget https://raw.githubusercontent.com/siriusted/gym-dssat-notebooks/master/install.sh
!chmod u+x install.sh
!./install.sh

- stable_baselines3

In [None]:
!pip install "stable-baselines3[extra]" &> /dev/null

All set! 

Next, we will train a PPO agent using stable-baselines3. This agent will be compared to two hardcoded agents, namely a Null agent and an Expert agent.
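
Such a comparison amounts to running each policy for a number of episodes and recording the return of every episode. The sketch below illustrates that loop in plain Python. `ToyEnv`, `null_policy`, and `fixed_policy` are hypothetical stand-ins invented for this illustration; they are not the gym-dssat environment or the agents in `sb_example.py`, and the reward rule is arbitrary.

```python
import random


class ToyEnv:
    """Hypothetical stand-in environment with a Gym-like reset/step interface."""

    def __init__(self, horizon=10, seed=0):
        self.horizon = horizon
        self.rng = random.Random(seed)
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t  # the observation is simply the time step

    def step(self, action):
        # Arbitrary reward: noisy benefit of acting minus a fixed cost per unit of action.
        reward = action * self.rng.uniform(0.5, 1.5) - 0.8 * action
        self.t += 1
        done = self.t >= self.horizon
        return self.t, reward, done, {}


def null_policy(obs):
    """Never act."""
    return 0.0


def fixed_policy(obs):
    """Act only at hand-picked time steps."""
    return 1.0 if obs in (2, 5) else 0.0


def evaluate(policy, env, n_episodes=20):
    """Return a list holding the undiscounted return of each episode."""
    returns = []
    for _ in range(n_episodes):
        obs, done, total = env.reset(), False, 0.0
        while not done:
            obs, reward, done, _ = env.step(policy(obs))
            total += reward
        returns.append(total)
    return returns


# (label, returns) pairs: one entry per policy, one return per episode.
results = [
    ("null", evaluate(null_policy, ToyEnv(seed=0))),
    ("fixed", evaluate(fixed_policy, ToyEnv(seed=1))),
]
```

Keeping one list of returns per policy, all of the same length, makes the results easy to tabulate and plot side by side.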

To use gym-dssat properly, we need to run commands using `pdirun`. As a result, the source code for the rest of the tutorial has been collected into a script, which we will fetch next, then run. 

Additionally, we will fetch a pretrained model that is also evaluated in the script.

In [None]:
!wget https://raw.githubusercontent.com/siriusted/gym-dssat-notebooks/master/sb_example.py
!wget https://raw.githubusercontent.com/siriusted/gym-dssat-notebooks/master/pretrained_model.zip

Run the example. Note that this will take some time. In the meantime, you can look at `sb_example.py` in the Colab file browser.

In [None]:
!/opt/pdi/bin/pdirun python sb_example.py
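
If you prefer to launch the script from Python rather than through a `!` shell escape, the same command can be built with `subprocess`. This is only a sketch: `pdirun_command` and `run_with_pdirun` are hypothetical helper names, and the `pdirun` path is simply the one used in the cell above.

```python
import subprocess
from pathlib import Path

PDIRUN = "/opt/pdi/bin/pdirun"  # path taken from the shell command above


def pdirun_command(script, python="python"):
    """Build the argument list for running `script` under pdirun."""
    return [PDIRUN, python, str(script)]


def run_with_pdirun(script):
    """Run `script` under pdirun; raises CalledProcessError if it exits non-zero."""
    if not Path(PDIRUN).exists():
        raise FileNotFoundError(f"{PDIRUN} not found; is gym-dssat installed?")
    return subprocess.run(pdirun_command(script), check=True)


print(pdirun_command("sb_example.py"))
```

Calling `run_with_pdirun("sb_example.py")` would then execute the script and fail loudly if `pdirun` is missing or the script exits with an error.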

Next, we load and display the results that the script stored in `results.pkl`.

In [None]:
import pickle

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def plot_results(data):
    """Draw one box per policy from an iterable of (label, returns) pairs."""
    # One column per policy; pandas requires all lists of returns to have the same length.
    df = pd.DataFrame(dict(data))

    ax = sns.boxplot(data=df)
    ax.set_xlabel("policy")
    ax.set_ylabel("evaluation output")
    plt.show()


# Load the results written by sb_example.py.
with open("results.pkl", "rb") as result_file:
    results = pickle.load(result_file)

In [None]:
plot_results(results)
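
A boxplot is easy to read but hard to quote. As a complement, the sketch below computes the mean and standard deviation per policy using only the standard library. It assumes, as `plot_results` does, that the results are an iterable of `(label, returns)` pairs. `summarize` is a hypothetical helper, and the `example` data is made up for illustration; it is not output from `sb_example.py`.

```python
from statistics import mean, stdev


def summarize(data):
    """Map each label to (mean, sample standard deviation) of its returns."""
    summary = {}
    for label, returns in data:
        values = [float(r) for r in returns]
        # stdev needs at least two values; report 0.0 for a single episode.
        spread = stdev(values) if len(values) > 1 else 0.0
        summary[label] = (mean(values), spread)
    return summary


# Made-up placeholder numbers, only to show the call.
example = [("policy_a", [1.0, 2.0, 3.0]), ("policy_b", [2.0, 2.0, 2.0])]
for label, (avg, spread) in summarize(example).items():
    print(f"{label}: mean={avg:.2f}, std={spread:.2f}")
```

In the notebook, `summarize(results)` can be called once `results.pkl` has been loaded.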

All done! Go ahead and edit `sb_example.py` in the integrated editor, then re-run the cells above, starting from the `pdirun` cell, to regenerate `results.pkl` and see the new results.