# Exploring RLlib for solar agent problems

This notebook explores using RLlib with the solar agent (battery control) environment. It is partly based on the [CartPole tutorial notebook by Anyscale](https://github.com/anyscale/academy/blob/9317775c393aff06cff06ae58c88f85ce201940d/ray-rllib/explore-rllib/01-Application-Cart-Pole.ipynb).

## 0. Setup

In [None]:
# Automatically reload edited modules (e.g. the solara package) before each cell runs
%load_ext autoreload
%autoreload 2
# Greedy tab completion: evaluates expressions on TAB to complete on their results
%config IPCompleter.greedy=True

In [None]:
import ray
import ray.tune
import ray.rllib
import json
import glob
import os
import pandas as pd
import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt

from solara.constants import PROJECT_PATH
import solara.envs.components.solar
import solara.envs.components.load
import solara.envs.components.grid
import solara.envs.battery_control
import solara.envs.components.battery

In [None]:
## Initialising ray (starts background process for distributed computing)
# Shut down any Ray instance left over from an earlier run of this cell so that
# ray.init() can be called again when the notebook is re-run.
ray.shutdown()
ray.init()

## 1. Setting up the solar agent environment

In [None]:
# To make the environment usable with RLlib
# we wrap its creation in a function that receives the trainer's env_config.
def battery_env_creator(env_config):
    """Create a BatteryControlEnv whose load and PV traces are fixed to one sample.

    Args:
        env_config: mapping passed in by RLlib (the trainer's ``"env_config"``).
            All keys are optional and default to the values previously
            hard-coded here:
            - ``"pv_data_path"``: PV trace file
              (default ``<PROJECT_PATH>/data/solar_trace_data/PV_5796.txt``).
            - ``"load_data_path"``: load trace file
              (default ``<PROJECT_PATH>/data/solar_trace_data/load_5796.txt``).
            - ``"episode_num"``: value passed to ``fix_start`` on both the
              load and PV models (default ``12``).

    Returns:
        The configured ``solara.envs.battery_control.BatteryControlEnv``.
    """
    env_config = env_config or {}
    pv_data_path = env_config.get(
        "pv_data_path", PROJECT_PATH + "/data/solar_trace_data/PV_5796.txt")
    load_data_path = env_config.get(
        "load_data_path", PROJECT_PATH + "/data/solar_trace_data/load_5796.txt")
    episode_num = env_config.get("episode_num", 12)

    # Setting up components of environment
    battery_model = solara.envs.components.battery.LithiumIonBattery(20, "NMC", 1/10.0)
    pv_model = solara.envs.components.solar.DataPV(data_path=pv_data_path)
    load_model = solara.envs.components.load.DataLoad(data_path=load_data_path)
    grid_model = solara.envs.components.grid.PeakGrid()

    # Fixing load and PV trace to single sample, so every episode replays it
    load_model.fix_start(episode_num)
    pv_model.fix_start(episode_num)

    return solara.envs.battery_control.BatteryControlEnv(
        battery=battery_model,
        pv_system=pv_model,
        grid=grid_model,
        load=load_model,
    )

# Register the creator under a name so the trainer can refer to it by string.
ray.tune.registry.register_env("battery_control", battery_env_creator)

## 2. Setting up the RLlib agent

In [None]:
# Everything this experiment writes to disk lives under save_path.
save_path = "./tmp/ppo/battery-control"
check_save_path = f"{save_path}/checkpoints"  # where trainer.save() puts checkpoints
out_save_path = f"{save_path}/outputs"        # where sampled experiences are written

ppo_config = {
    "framework": "torch",
    # battery_env_creator receives this mapping; empty means "use its defaults"
    "env_config": {},
    # record sampled experiences as JSON files in out_save_path
    "output": out_save_path,
    # compress no columns so "obs" can be read back directly with json (section 4)
    "output_compress_columns": [],
    # discount factor deliberately set very close to 1
    "gamma": 0.9999999,
}

trainer = ray.rllib.agents.ppo.PPOTrainer(env="battery_control", config=ppo_config)

## 3. Training agent on environment

In [None]:
num_iterations = 10
iteration_string = "Training iteration: {}, Min reward: {:.3f}, Mean reward: {:.3f}, Max reward: {:.3f}."
reward_stats = ("episode_reward_min", "episode_reward_mean", "episode_reward_max")

# One trainer.train() call per iteration; a checkpoint is written after each,
# and file_name ends up holding the path of the most recent one.
for i in range(num_iterations):
    iteration_out = trainer.train()
    print(iteration_string.format(i, *(iteration_out[stat] for stat in reward_stats)))
    file_name = trainer.save(check_save_path)

print("Training completed")

## 4. Visualising episodes recorded during training

In [None]:
## Loading latest output json file

# Getting latest file path. Use the modification time: os.path.getctime is the
# metadata-change time on Unix (creation time only on Windows).
list_of_out_files = glob.glob(out_save_path + "/*.json")
if not list_of_out_files:
    raise FileNotFoundError(
        "No *.json output files found in {!r}; run the training cell first.".format(out_save_path))
latest_out_file = max(list_of_out_files, key=os.path.getmtime)

episode_trace_data = {"obs": [], "actions": []}

# Note that each line is separate Json data.
# If the entire file is loaded at once as a Json, it is broken.
# Therefore each line needs to be loaded separately and its columns appended.
with open(latest_out_file) as json_file:
    for line in json_file:
        if not line.strip():
            continue  # tolerate blank lines; json.loads("") would raise
        line_data = json.loads(line)
        for key, trace in episode_trace_data.items():
            trace.extend(line_data[key])

In [None]:
## Creating plot
# Episodes are taken as consecutive 24-step windows of the recorded trace
# (NOTE(review): assumes 24 steps per episode — confirm against the env).
# Valid zero-based indices are 0 .. ceil(len/24) - 1. The previous range
# (1, len//24) skipped episode 0 and could select an empty slice. save_path is
# wrapped in widgets.fixed so interact does not turn it into a text box that
# saves a file on every keystroke.
@widgets.interact(ep_num=(0, max((len(episode_trace_data["obs"]) - 1) // 24, 0)),
                  save_path=widgets.fixed("episode_plot.png"))
def plot_episode(ep_num, save_path="episode_plot.png"):
    """Plot observations and actions of one recorded episode and save the figure.

    Args:
        ep_num: zero-based episode index; selects entries
            ``24*ep_num`` to ``24*(ep_num+1)`` of ``episode_trace_data``.
        save_path: file the figure is written to.
    """
    steps = slice(24 * ep_num, 24 * (ep_num + 1))

    df = pd.DataFrame(data=episode_trace_data["obs"][steps])
    df.columns = ["load",
                  "pv_generation",
                  "battery__energy_content",
                  "time_step",
                  "sum_load",
                  "sum_pv_gen",]
    df["actions"] = np.array(episode_trace_data["actions"][steps])

    # "sum_load" and "sum_pv_gen" are intentionally left out of the plot.
    y_names = ["load",
               "pv_generation",
               "battery__energy_content",]

    fig, axs = plt.subplots(2, 1, figsize=(8, 6), dpi=100)
    fig.suptitle('Episode {}'.format(ep_num))

    df.plot(x="time_step", y=y_names, ylim=[-0.1, 5], color=["blue", "green", "black"], ax=axs[0])
    df.plot(x="time_step", y=["actions"], ylim=[-2, 2], color='red', ax=axs[1])

    fig.savefig(save_path)