In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm 
import random
from itertools import product

import wandb
import yaml

from ADR_Environment import ADR_Environment
from rl_glue import RLGlue
from Complete_pytorch import Transition, ReplayBuffer, ActionValueNetwork


class Agent():
    """
    Deep Q-Learning agent with experience replay.

    Implements the RL-Glue agent interface (``agent_init``, ``agent_start``,
    ``agent_step``, ``agent_end``, ``agent_message``). Before each batch of
    replay updates the target network is re-synced with the policy network,
    so the TD targets stay fixed for the ``num_replay_updates_per_step``
    updates performed on that step.
    """
    def __init__(self):
        self.name = 'dqn'

    def agent_init(self, agent_config):
        """
        Build the replay buffer, networks and optimizer.
        Args:
            - agent_config (dict): must provide "device", "replay_buffer_size",
                "minibatch_size", "network_config" (including "num_actions"),
                "optimizer_config" ("step_size", "beta_m", "beta_v",
                "epsilon"), "num_replay_updates_per_step", "gamma" and "tau".
        """
        self.device = agent_config['device']
        self.replay_buffer = ReplayBuffer(agent_config["replay_buffer_size"],
                                          agent_config["minibatch_size"]
                                          )
        self.policy_network = ActionValueNetwork(agent_config["network_config"]).to(self.device)
        self.target_network = ActionValueNetwork(agent_config["network_config"]).to(self.device)
        self.optimizer = torch.optim.AdamW(self.policy_network.parameters(),
                              lr = agent_config["optimizer_config"]["step_size"],
                              betas = (agent_config["optimizer_config"]["beta_m"], agent_config["optimizer_config"]["beta_v"]),
                              eps = agent_config["optimizer_config"]["epsilon"],
                              amsgrad = True)
        self.num_actions = agent_config["network_config"]["num_actions"]
        self.num_replay = agent_config["num_replay_updates_per_step"]
        self.discount = agent_config["gamma"]
        self.tau = agent_config["tau"]
        self.last_state = None
        self.last_action = None
        self.sum_rewards = 0
        self.episode_steps = 0
        self.ep_loss = 0

    def policy(self, state):
        """
        Args:
            - state (Numpy array): the state.
        Returns:
            - the action (int).
        """
        state = torch.tensor(state, device = self.device, dtype = torch.float32)
        # Action selection is delegated to ActionValueNetwork.select_action;
        # tau is presumably a softmax temperature -- confirm in that class.
        action = self.policy_network.select_action(state, self.tau)
        return action

    def agent_start(self, state):
        """
        The first method called when the experiment starts, called after
        the environment starts.
        Args:
            - state: the value returned by the environment's env_start
                function; only its element at index 1 is used as the
                observation.
        Returns:
            - The first action the agent takes (int).
        """
        self.ep_loss = 0
        self.sum_rewards = 0
        self.episode_steps = 0
        self.last_state = np.array([state[1]])

        self.last_action = self.policy(self.last_state)
        return self.last_action

    def agent_step(self, reward, state):
        """
        A step taken by the agent.
        Args:
            - reward (float): the reward received for taking the last action taken
            - state (Numpy array): the state from the
                environment's step based, where the agent ended up after the
                last step
        Returns:
            - The action the agent is taking (int).
        """
        self.sum_rewards += reward
        self.episode_steps += 1
        # Add a leading axis so stored states match the shape used in agent_start.
        state = np.array([state])

        action = self.policy(state)
        self.replay_buffer.append(self.last_state , self.last_action , reward , 0 , state)
        self._replay()
        self.last_state = state
        self.last_action = action
        return action

    def agent_end(self, reward):
        """
        Run when the agent terminates.
        Args:
            - reward (float): the reward the agent received for entering the
                terminal state.
        """
        self.sum_rewards += reward
        self.episode_steps += 1
        # Placeholder next state for the terminal transition; its value is
        # masked out of the TD target in optimize_network.
        state = np.zeros_like(self.last_state)
        self.replay_buffer.append(self.last_state , self.last_action , reward , 1 , state)
        self._replay()

    def agent_message(self, message):
        """
        Answer queries from the experiment loop.
        Args:
            - message (str): "get_sum_reward" or "get_loss".
        Returns:
            - the episode's summed reward, or the summed training loss (float).
        Raises:
            - Exception: for any other message.
        """
        if message == "get_sum_reward":
            return self.sum_rewards
        elif message == "get_loss":
            return self.ep_loss
        else:
            raise Exception("Unrecognised Message !")

    def _replay(self):
        """Run the replay updates once the buffer holds more than one minibatch."""
        if self.replay_buffer.size() <= self.replay_buffer.minibatch_size:
            return
        # Freeze the current policy weights as the target for this step's updates.
        self.target_network.load_state_dict(self.policy_network.state_dict())
        for _ in range(self.num_replay):
            experiences = self.replay_buffer.sample()
            self.optimize_network(experiences)

    def optimize_network(self, experiences):
        """
        Perform one gradient step on the policy network from a minibatch.
        Args:
            - experiences (list): transitions sampled from the replay buffer;
                they must be unpackable into ``Transition`` (fields state,
                action, reward, terminal, next_state are read).
        Side effects:
            - updates the policy network and adds the batch loss (float) to
                ``self.ep_loss``.
        """
        # transpose the batch from batch-array of transitions to Transition of batch-array
        batch = Transition(*zip(*experiences))

        # Stack with numpy first: building a tensor from a tuple of ndarrays is slow.
        state_batch = torch.as_tensor(np.array(batch.state), device=self.device, dtype=torch.float32).squeeze(1)
        next_state_batch = torch.as_tensor(np.array(batch.next_state), device=self.device, dtype=torch.float32).squeeze(1)
        action_batch = torch.tensor(batch.action, device=self.device).unsqueeze(-1)
        reward_batch = torch.tensor(batch.reward, device=self.device, dtype=torch.float32)
        # 1.0 where the next state is non-terminal, 0.0 where the episode ended.
        # Used multiplicatively: an int64 tensor here would act as an index list
        # rather than a mask.
        non_final_mask = torch.tensor([not t for t in batch.terminal], device=self.device, dtype=torch.float32)

        # Compute Q(s_t, a)
        state_action_values = self.policy_network(state_batch).gather(1, action_batch)

        # V(s_{t+1}) = max_a Q_target(s_{t+1}, a), zeroed for terminal transitions.
        with torch.no_grad():
            next_state_values = self.target_network(next_state_batch).max(1).values * non_final_mask

        # Compute the expected Q values (TD targets)
        expected_state_action_values = reward_batch + self.discount * next_state_values

        # Compute Huber loss
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # .item() keeps the fractional part and detaches from the device/graph.
        self.ep_loss += loss.item()

        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(self.policy_network.parameters(), 100)
        self.optimizer.step()






def run_experiment(environment , agent , environment_parameters , agent_parameters , experiment_parameters):
    """
    Run the experiment: ``num_runs`` runs of ``num_episodes`` episodes each,
    logging per-episode metrics to Weights & Biases.

    Args:
        - environment: environment class handed to RLGlue.
        - agent: agent class handed to RLGlue.
        - environment_parameters (dict): base configuration passed to the
            environment through rl_init (a "seed" key is added to a copy).
        - agent_parameters (dict): agent configuration. Its "seed" and
            ["network_config"]["seed"] entries are overwritten in place.
        - experiment_parameters (dict): needs "num_runs", "num_episodes",
            "timeout" and "gpu_use".
    Returns:
        - numpy array of shape (num_runs, num_episodes) holding the summed
            reward of every episode.
    """
    rl_glue = RLGlue(environment, agent)
    num_runs = experiment_parameters["num_runs"]
    num_episodes = experiment_parameters["num_episodes"]
    gpu_use = experiment_parameters['gpu_use']
    agent_sum_reward = np.zeros((num_runs, num_episodes))
    # Copy so the caller's dict is not mutated when the seed is added.
    env_info = dict(environment_parameters or {})
    agent_info = agent_parameters
    for run in range(1 , num_runs + 1):
        seed = 0  # fixed for every run; was previously `run`
        agent_info["seed"] = seed
        agent_info["network_config"]["seed"] = seed
        env_info["seed"] = seed

        # Seed every RNG *before* rl_init: agent_init builds the networks, so
        # seeding afterwards left the weight initialisation unseeded.
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if gpu_use and torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        rl_glue.rl_init(agent_info , env_info)

        for episode in range(1 , num_episodes + 1):
            rl_glue.rl_episode(experiment_parameters["timeout"])

            # Get data from episode
            episode_reward = rl_glue.rl_agent_message("get_sum_reward")
            ep_loss = rl_glue.rl_agent_message("get_loss")
            fuel_limit, time_limit, impossible_dt, impossible_binary_flag = rl_glue.environment.get_term_reason()
            avg_fuel_used = rl_glue.environment.get_fuel_use_average()
            avg_time_used = rl_glue.environment.get_time_use_average()
            agent_sum_reward[run - 1, episode - 1] = episode_reward

            wandb.log({
                    "episode loss": ep_loss,
                    "episode reward": episode_reward ,
                    "fuel limit": fuel_limit,
                    "time limit": time_limit,
                    "impossible_dt": impossible_dt,
                    "impossible_binary_flag": impossible_binary_flag,
                    "average fuel used":avg_fuel_used,
                    "average time used":avg_time_used
                })

    # Mean episode reward over all runs (identical to run 0's mean when num_runs == 1).
    wandb.log({"avg_reward": float(agent_sum_reward.mean())})
    return agent_sum_reward

In [None]:
def main():
    """
    Entry point for a single W&B sweep trial.

    Starts a W&B run, assembles the experiment and agent configuration from
    fixed values plus the hyperparameters in ``wandb.config``, selects the
    torch device (CUDA when requested and available, else CPU) and launches
    ``run_experiment`` with ``ADR_Environment`` and ``Agent``.
    """
    wandb.init()
    hparams = wandb.config

    experiment_parameters = {
        "num_runs": 1,
        "num_episodes": 2000,
        "timeout": 2000,
        "gpu_use": True,
        "track_wandb": False,
    }

    network_config = {
        "state_dim": 25,
        "num_hidden_units": 512,
        "num_actions": 300,
        # previously tried: 'models/test_weights.pth'
        "weights_file": None,
    }
    optimizer_config = {
        "step_size": hparams.learning_rate,  # 1e-3 was a working value
        "beta_m": 0.9,
        "beta_v": 0.999,
        "epsilon": 1e-8,
    }
    agent_parameters = {
        "network_config": network_config,
        "optimizer_config": optimizer_config,
        "replay_buffer_size": hparams.replay_buffer_size,
        "minibatch_size": hparams.minibatch_size,
        "num_replay_updates_per_step": hparams.replay_updates_per_step,
        "gamma": hparams.gamma,
        "tau": hparams.tau,
        "seed": 0,
    }

    # Fall back to the CPU when CUDA is disabled or unavailable.
    want_cuda = experiment_parameters['gpu_use'] and torch.cuda.is_available()
    device = torch.device("cuda" if want_cuda else "cpu")
    agent_parameters['device'] = device
    print(device)

    run_experiment(ADR_Environment, Agent, {}, agent_parameters, experiment_parameters)

## Run the hyperparameter sweep

In [None]:
# Load the sweep definition. Change the file name to run a different sweep.
# safe_load is sufficient for a plain-data sweep config and, unlike
# FullLoader, never constructs arbitrary Python objects.
with open("./sweep_config/grid.yaml", encoding="utf-8") as file:
    sweep_configuration = yaml.safe_load(file)

In [None]:
# Register the sweep under the "HPO-ADR" W&B project; the returned id is
# what wandb.agent is given to run trials.
sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project="HPO-ADR",
)

In [None]:
# Start the sweep agent; it calls main() once per trial.
# For a Bayesian search, pass count=20 to bound the number of trials;
# for a grid search leave count unset.
wandb.agent(
    sweep_id,
    function=main,
)