<a href="https://colab.research.google.com/github/siriusted/gym-dssat-notebooks/blob/master/SB3_example_with_env_vars.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gym-DSSAT x Stable-Baselines3 Tutorial

Welcome to a brief introduction to using gym-dssat with stable-baselines3.

For background on reinforcement learning with stable-baselines3, see [this tutorial](https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3).

This notebook assumes familiarity with reinforcement learning and stable-baselines3 (SB3), so it focuses on interacting with gym-dssat through SB3.

Next, we install the dependencies.

**Note**: This takes a while.

# Installation
- gym_dssat

In [1]:
!wget https://raw.githubusercontent.com/siriusted/gym-dssat-notebooks/master/install.sh
!chmod u+x install.sh

--2022-02-22 12:25:49--  https://raw.githubusercontent.com/siriusted/gym-dssat-notebooks/master/install.sh
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2162 (2.1K) [text/plain]
Saving to: ‘install.sh’


2022-02-22 12:25:49 (31.0 MB/s) - ‘install.sh’ saved [2162/2162]



In [2]:
!./install.sh


In [4]:
!cd gym_dssat_pdi/ && git branch

* dev
  stable


- stable_baselines3

In [3]:
!pip install stable-baselines3[extra]

Collecting stable-baselines3[extra]
  Downloading stable_baselines3-1.4.0-py3-none-any.whl (176 kB)
     |████████████████████████████████| 176 kB 5.1 MB/s
Collecting atari-py==0.2.6
  Downloading atari_py-0.2.6-cp37-cp37m-manylinux1_x86_64.whl (2.8 MB)
     |████████████████████████████████| 2.8 MB 43.2 MB/s
Reason for being yanked: re-release with new wheels
Installing collected packages: stable-baselines3, atari-py
  Attempting uninstall: atari-py
    Found existing installation: atari-py 0.2.9
    Uninstalling atari-py-0.2.9:
      Successfully uninstalled atari-py-0.2.9
Successfully installed atari-py-0.2.6 stable-baselines3-1.4.0


All set! 

Next, we will train a PPO agent using stable-baselines3. This agent will be compared to two hardcoded agents, namely a Null agent and an Expert agent.
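A hard-coded agent is just a fixed function from observation to action, so comparing agents amounts to running each one for a few episodes and collecting the returns. The sketch below assumes the Null agent always outputs the lowest normalized action, -1, which the wrapper defined later rescales to the lower bound of the `anfer` action space (assumed here to mean "apply no fertilizer"). `null_policy`, `evaluate`, and `StubEnv` are hypothetical helpers written for illustration, not part of gym-dssat or `sb_example.py`; the stub stands in for the real environment so the loop runs anywhere, and it uses the same 4-tuple `step` API as the wrapper.

```python
import numpy as np


def null_policy(obs):
    """Hypothetical Null agent: always the lowest normalized action (-1).

    After the wrapper's rescaling, -1 maps to the lower bound of the 'anfer'
    action space, assumed here to mean "apply no fertilizer".
    """
    return np.array([-1.0], dtype=np.float32)


def evaluate(env, policy, n_episodes=3):
    """Return the undiscounted return of each episode (gym 4-tuple step API)."""
    returns = []
    for _ in range(n_episodes):
        obs, done, total = env.reset(), False, 0.0
        while not done:
            obs, reward, done, _info = env.step(policy(obs))
            total += reward
        returns.append(total)
    return returns


class StubEnv:
    """Hypothetical 3-step stand-in for the wrapped gym-dssat environment."""

    def reset(self):
        self.t = 0
        return np.zeros(2, dtype=np.float32)

    def step(self, action):
        self.t += 1
        reward = -abs(float(action[0]))  # toy reward, not a DSSAT quantity
        return np.zeros(2, dtype=np.float32), reward, self.t >= 3, {}


print(evaluate(StubEnv(), null_policy))  # [-3.0, -3.0, -3.0]
```

An expert agent would be another function of `obs` with the same signature; its rules are presumably defined in `sb_example.py` and are not reproduced here.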

gym-dssat must be run through `pdirun`, which loads the PDI environment. For this reason, the code for the rest of the tutorial is collected in a script, `sb_example.py`, which we fetch next and then run through `pdirun`.

In [None]:
!wget https://raw.githubusercontent.com/siriusted/gym-dssat-notebooks/master/sb_example.py

Now run the example; this takes some time. Meanwhile, you can inspect `sb_example.py` in the Colab file browser. The script's console output is redirected to `result.out`.

In [5]:
!/opt/pdi/bin/pdirun python sb_example.py > result.out

Next, we load and plot the results, which the script stored in `results.pkl`.

In [None]:
import pickle
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

def plot_results(data):
    """Box-plot evaluation returns; `data` is an iterable of (label, returns) pairs."""
    # one DataFrame column per policy label
    df = pd.DataFrame(dict(data))

    ax = sns.boxplot(data=df)
    ax.set_xlabel("policy")
    ax.set_ylabel("evaluation output")
    plt.show()

# load the (label, returns) pairs saved by sb_example.py
with open("results.pkl", "rb") as result_file:
    results = pickle.load(result_file)

In [None]:
plot_results(results)
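`plot_results` iterates over `results` as `(label, returns)` pairs, presumably one pair per policy. The sketch below builds a list in that shape with dummy placeholder numbers (not results from `sb_example.py`), round-trips it through `pickle` in memory, and turns it into the same DataFrame that `plot_results` plots. The labels `"null"`, `"ppo"`, and `"expert"` are hypothetical.

```python
import pickle

import pandas as pd

# Hypothetical stand-in for results.pkl: (label, returns) pairs, one per policy.
# The numbers are dummy placeholders, not results from sb_example.py.
dummy_results = [
    ("null", [1.0, 2.0, 3.0]),
    ("ppo", [4.0, 5.0, 6.0]),
    ("expert", [7.0, 8.0, 9.0]),
]

# in-memory pickle round trip, standing in for writing and reading results.pkl
loaded = pickle.loads(pickle.dumps(dummy_results))

# same {label: returns} mapping that plot_results builds: one column per policy
df = pd.DataFrame(dict(loaded))
print(df.mean())
```

Because each return list becomes one column, `pd.DataFrame` raises a `ValueError` if the policies were evaluated on different numbers of episodes.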

All done! Feel free to edit `sb_example.py` in Colab's built-in editor, then re-run the cells above, starting from the `pdirun` cell, to see the updated results.

In [16]:
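# set up environment variables the way `pdirun` does
import os
import sys

PDI_DIR = '/opt/pdi'

# append to colon-separated search paths, creating them if they are unset
for var, subdir in [('PATH', 'bin'), ('LD_LIBRARY_PATH', 'lib'),
                    ('PYTHONPATH', 'lib/python3/dist-packages'), ('LIBRARY_PATH', 'lib')]:
    current = os.environ.get(var)
    new_entry = f'{PDI_DIR}/{subdir}/'
    os.environ[var] = f'{current}:{new_entry}' if current else new_entry
os.environ['CPATH'] = f'{PDI_DIR}/include/:{PDI_DIR}/lib/pdi/finclude/GNU-7.5'

# PYTHONPATH is only read at interpreter startup, so it reaches subprocesses
# such as `!pdirun`; extend sys.path so the running kernel can import from it too
sys.path.append(f'{PDI_DIR}/lib/python3/dist-packages/')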
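Changes to `os.environ` are inherited by processes started afterwards, such as `!pdirun` or `!python`, but `PYTHONPATH` is only read when an interpreter starts, so setting it does not change `sys.path` in the kernel that is already running. The sketch below checks this with a hypothetical directory name and restores the previous `PYTHONPATH` afterwards, so the PDI settings from the cell above are left untouched. If the PDI Python bindings must be importable in the kernel itself, appending the directory to `sys.path` is the usual workaround.

```python
import os
import subprocess
import sys

# Hypothetical directory, used only for this demonstration
extra_dir = "/tmp/hypothetical_pdi_dist_packages"

saved = os.environ.get("PYTHONPATH")
os.environ["PYTHONPATH"] = extra_dir
try:
    # The running interpreter read PYTHONPATH at startup, so sys.path is unchanged
    in_kernel = extra_dir in sys.path
    # A child process inherits os.environ and adds the entry to its own sys.path
    child = subprocess.run(
        [sys.executable, "-c", f"import sys; print({extra_dir!r} in sys.path)"],
        capture_output=True, text=True, check=True,
    )
    in_child = child.stdout.strip() == "True"
finally:
    # restore the previous value so the notebook's PDI settings are untouched
    if saved is None:
        os.environ.pop("PYTHONPATH", None)
    else:
        os.environ["PYTHONPATH"] = saved

print(in_kernel, in_child)  # expected: False True
```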

In [21]:
!pdirun

Environment loaded for PDI version 1.5.0-alpha.2022-02-18.fef10ca


In [22]:
import gym
import gym_dssat_pdi
import random
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
from stable_baselines3.common.callbacks import EvalCallback
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7fd1eb5d45d0>

In [23]:
# helpers for action normalization
def normalize_action(action_space_limits, action):
    """Normalize the action from [low, high] to [-1, 1]"""
    low, high = action_space_limits
    return 2.0 * ((action - low) / (high - low)) - 1.0

def denormalize_action(action_space_limits, action):
    """Denormalize the action from [-1, 1] to [low, high]"""
    low, high = action_space_limits
    return low + (0.5 * (action + 1.0) * (high - low))

# Wrapper for easy and uniform interfacing with SB3
class GymDssatWrapper(gym.Wrapper):
    def __init__(self, env):
        super(GymDssatWrapper, self).__init__(env)

        self.action_low, self.action_high = self._get_action_space_bounds()

        # using a normalized action space
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype="float32")

        # using a vector representation of observations to allow
        # easily using SB3 MlpPolicy
        self.observation_space = gym.spaces.Box(low=0.0,
                                                high=np.inf,
                                                shape=env.observation_dict_to_array(
                                                    env.observation).shape,
                                                dtype="float32"
                                                )

        # gym-dssat returns None for obs and info on the final step; caching the
        # last non-terminal values avoids passing None to Monitor
        self.last_info = {}
        self.last_obs = None

    def _get_action_space_bounds(self):
        box = self.env.action_space['anfer']
        return box.low, box.high

    def _format_action(self, action):
        return { 'anfer': action[0] }

    def _format_observation(self, observation):
        return self.env.observation_dict_to_array(observation)

    def reset(self):
        return self._format_observation(self.env.reset())


    def step(self, action):
        # Rescale action from [-1, 1] to original action space interval
        denormalized_action = denormalize_action((self.action_low, self.action_high), action)
        formatted_action = self._format_action(denormalized_action)
        obs, reward, done, info = self.env.step(formatted_action)

        # handle `None`s in obs, reward, and info on done step
        if done:
            obs, reward, info = self.last_obs, 0, self.last_info
        else:
            self.last_obs = obs
            self.last_info = info

        formatted_observation = self._format_observation(obs)
        return formatted_observation, reward, done, info

    def close(self):
        return self.env.close()

    def seed(self, seed):
        self.env.set_seed(seed)

    def eval(self):
        self.env.set_evaluation()


    def __del__(self):
        self.close()
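As a sanity check of the rescaling used in `GymDssatWrapper.step`, the sketch below maps a few policy outputs from [-1, 1] to fertilizer amounts and back. The helpers are copied from the cell above so the block runs on its own; the bounds `(0, 200)` are hypothetical, since the wrapper reads the real ones from `env.action_space['anfer']`. Note that the wrapper itself only calls `denormalize_action`.

```python
import numpy as np

# Copies of the notebook's helpers so this block runs on its own
def normalize_action(action_space_limits, action):
    """Normalize the action from [low, high] to [-1, 1]"""
    low, high = action_space_limits
    return 2.0 * ((action - low) / (high - low)) - 1.0

def denormalize_action(action_space_limits, action):
    """Denormalize the action from [-1, 1] to [low, high]"""
    low, high = action_space_limits
    return low + (0.5 * (action + 1.0) * (high - low))

# Hypothetical bounds; the wrapper reads the real ones from action_space['anfer']
limits = (np.array([0.0], dtype=np.float32), np.array([200.0], dtype=np.float32))

# Normalized actions as a policy would emit them (shape (1,) each)
policy_outputs = np.array([[-1.0], [0.0], [1.0]], dtype=np.float32)
doses = [denormalize_action(limits, a) for a in policy_outputs]
print([float(d[0]) for d in doses])  # [0.0, 100.0, 200.0]

# Round trip: normalizing a denormalized action recovers the original
recovered = [normalize_action(limits, d) for d in doses]
print(np.allclose(np.concatenate(recovered), policy_outputs.ravel()))  # True

# The 'anfer' value _format_action would build for a policy output of 0.0
print({'anfer': float(doses[1][0])})  # {'anfer': 100.0}
```

Because `denormalize_action` is linear and does not clip, a raw value outside [-1, 1] would map outside [low, high]. SB3 clips actions to the declared `Box` action space before calling `step`, so in practice the wrapper receives values within [-1, 1].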

In [24]:
# Create environment
env_args = {
    'run_dssat_location': '/opt/dssat_pdi/run_dssat',
    'mode': 'fertilization',
    'seed': 123,
    'random_weather': True,
}

env = Monitor(GymDssatWrapper(gym.make('GymDssatPdi-v0', **env_args)))

UnregisteredEnv: ignored