In [1]:
from utilities.Network import Network
from utilities.ReplayBuffer import ReplayBuffer

import wandb
import json
import matplotlib.pyplot as plt
import torch
from Discrete_SAC_Agent import SACAgent

from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.base_env import ActionTuple

from debug_side_channel import DebugSideChannel
from gym import spaces

import torch.nn.functional as F
from torch.distributions import Normal
import sys
import numpy as np
import pandas as pd


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Soft Actor-Critic hyperparameters --------------------------------------
# Values tried previously (kept for reference):
#   ALPHA_INITIAL = 1., DISCOUNT_RATE = 0.99, SOFT_UPDATE_INTERPOLATION_FACTOR = 0.01
ALPHA_INITIAL = 1.0                      # starting entropy temperature
DISCOUNT_RATE = 0.0                      # 0 => the critic target is the immediate reward only
LEARNING_RATE = 10 ** -4                 # shared by the critic, actor and temperature optimisers
SOFT_UPDATE_INTERPOLATION_FACTOR = 0.99  # weight of the local network in each target update
REPLAY_BUFFER_BATCH_SIZE = 132           # transitions per gradient update

# --- Training schedule ------------------------------------------------------
RUNS = 2
EPISODES_PER_RUN = 400
STEPS_PER_EPISODE = 200
TRAINING_EVALUATION_RATIO = 4            # every 4th episode is a greedy evaluation episode

# --- Logging ----------------------------------------------------------------
WANDB = True                             # send metrics to Weights & Biases

In [3]:
# Open a Weights & Biases run only when metric logging is switched on.
if WANDB:
    wandb.init(project="visibility-game")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mr-marr747[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
class SACAgent:
    """Discrete-action Soft Actor-Critic agent.

    The agent owns two local critics, two target critics, a softmax actor, a
    learnable entropy temperature (stored as ``log_alpha``) and a replay buffer.
    Hyperparameters are read from the notebook-level constants when the agent
    is constructed.
    """

    # Rewards are multiplied by this factor before entering the critic target.
    REWARD_SCALE = 100

    def __init__(self, environment):
        """Build all networks and optimisers.

        Args:
            environment: object exposing ``action_space.n``; it is also handed
                to ``ReplayBuffer``.
        """
        self.environment = environment
        self.LEARNING_RATE = LEARNING_RATE
        self.ALPHA_INITIAL = ALPHA_INITIAL
        self.REPLAY_BUFFER_BATCH_SIZE = REPLAY_BUFFER_BATCH_SIZE
        self.DISCOUNT_RATE = DISCOUNT_RATE
        # Hard-coded observation size (was: self.environment.observation_space.shape[0]).
        self.state_dim = 3
        self.action_dim = self.environment.action_space.n
        self.critic_local = Network(input_dimension=self.state_dim,
                                    output_dimension=self.action_dim)
        self.critic_local2 = Network(input_dimension=self.state_dim,
                                     output_dimension=self.action_dim)
        self.critic_optimiser = torch.optim.Adam(self.critic_local.parameters(), lr=self.LEARNING_RATE)
        self.critic_optimiser2 = torch.optim.Adam(self.critic_local2.parameters(), lr=self.LEARNING_RATE)

        self.critic_target = Network(input_dimension=self.state_dim,
                                     output_dimension=self.action_dim)
        self.critic_target2 = Network(input_dimension=self.state_dim,
                                      output_dimension=self.action_dim)

        # tau=1 makes the target critics exact copies of the local critics.
        self.soft_update_target_networks(tau=1.)

        self.actor_local = Network(
            input_dimension=self.state_dim,
            output_dimension=self.action_dim,
            output_activation=torch.nn.Softmax(dim=1)
        )
        self.actor_optimiser = torch.optim.Adam(self.actor_local.parameters(), lr=self.LEARNING_RATE)

        self.replay_buffer = ReplayBuffer(self.environment)

        # 98% of the entropy of a uniform policy over the available actions.
        self.target_entropy = 0.98 * -np.log(1 / self.environment.action_space.n)
        self.log_alpha = torch.tensor(np.log(self.ALPHA_INITIAL), requires_grad=True)
        # alpha is exp(log_alpha); it is detached so that only the temperature
        # loss (never the actor loss) produces gradients for log_alpha.
        self.alpha = self.log_alpha.exp().detach()
        self.alpha_optimiser = torch.optim.Adam([self.log_alpha], lr=self.LEARNING_RATE)

    def get_next_action(self, state, evaluation_episode=False):
        """Return a greedy action in evaluation episodes, a sampled one otherwise."""
        if evaluation_episode:
            discrete_action = self.get_action_deterministically(state)
        else:
            discrete_action = self.get_action_nondeterministically(state)
        return discrete_action

    def get_action_nondeterministically(self, state):
        """Sample an action index from the actor's probability distribution."""
        action_probabilities = self.get_action_probabilities(state)
        discrete_action = np.random.choice(range(self.action_dim), p=action_probabilities)
        return discrete_action

    def get_action_deterministically(self, state):
        """Return the action index with the highest actor probability."""
        action_probabilities = self.get_action_probabilities(state)
        discrete_action = np.argmax(action_probabilities)
        return discrete_action

    def train_on_transition(self, state, discrete_action, next_state, reward, done):
        """Store one transition and run a training update if enough data is buffered."""
        transition = (state, discrete_action, reward, next_state, done)
        self.train_networks(transition)

    def test_networks(self, minibatch):
        """Run one full update on a caller-supplied minibatch (debugging aid).

        Args:
            minibatch: iterable of ``(state, action, reward, next_state, done)``.
        """
        # NOTE(review): critic_loss() still calls replay_buffer.update_weights()
        # although this minibatch was not sampled from the buffer - confirm that
        # ReplayBuffer tolerates this.
        self._update_from_minibatch(minibatch)

    def train_networks(self, transition):
        """Add ``transition`` to the replay buffer and, once it holds at least
        one batch, perform one update on a sampled minibatch."""
        self.replay_buffer.add_transition(transition)
        if self.replay_buffer.get_size() >= self.REPLAY_BUFFER_BATCH_SIZE:
            minibatch = self.replay_buffer.sample_minibatch(self.REPLAY_BUFFER_BATCH_SIZE)
            self._update_from_minibatch(minibatch)

    def _minibatch_to_tensors(self, minibatch):
        """Split a minibatch of transitions into five column tensors."""
        columns = list(map(list, zip(*minibatch)))
        states_tensor = torch.tensor(np.array(columns[0]), dtype=torch.float32)
        actions_tensor = torch.tensor(np.array(columns[1]).astype(np.int64))
        rewards_tensor = torch.tensor(np.array(columns[2])).float() * self.REWARD_SCALE
        next_states_tensor = torch.tensor(np.array(columns[3]), dtype=torch.float32)
        # Forced to bool because critic_loss() negates this tensor with `~`.
        done_tensor = torch.tensor(np.array(columns[4]), dtype=torch.bool)
        return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, done_tensor

    def _update_from_minibatch(self, minibatch):
        """One gradient step for critics, actor and temperature, then a soft target update."""
        states_tensor, actions_tensor, rewards_tensor, next_states_tensor, done_tensor = \
            self._minibatch_to_tensors(minibatch)

        # Clear gradients left over from the previous update. The actor loss
        # below also back-propagates into the critics, so skipping this would
        # let stale gradients accumulate from call to call.
        self.critic_optimiser.zero_grad()
        self.critic_optimiser2.zero_grad()
        self.actor_optimiser.zero_grad()
        self.alpha_optimiser.zero_grad()

        critic_loss, critic2_loss = \
            self.critic_loss(states_tensor, actions_tensor, rewards_tensor, next_states_tensor, done_tensor)
        if WANDB:
            wandb.log({"critic_loss": critic_loss.item()})
            wandb.log({"critic2_loss": critic2_loss.item()})

        critic_loss.backward()
        critic2_loss.backward()
        self.critic_optimiser.step()
        self.critic_optimiser2.step()

        actor_loss, log_action_probabilities = self.actor_loss(states_tensor)
        if WANDB:
            wandb.log({"actor_loss": actor_loss.item()})

        actor_loss.backward()
        self.actor_optimiser.step()

        alpha_loss = self.temperature_loss(log_action_probabilities)
        alpha_loss.backward()
        self.alpha_optimiser.step()
        self.alpha = self.log_alpha.exp().detach()

        self.soft_update_target_networks()

    def critic_loss(self, states_tensor, actions_tensor, rewards_tensor, next_states_tensor, done_tensor):
        """Return the mean squared soft-Bellman error of each local critic.

        Side effect: passes the per-sample minimum of the two squared errors to
        ``replay_buffer.update_weights``.
        """
        with torch.no_grad():
            action_probabilities, log_action_probabilities = self.get_action_info(next_states_tensor)
            next_q_values_target = self.critic_target.forward(next_states_tensor)
            next_q_values_target2 = self.critic_target2.forward(next_states_tensor)
            # Expected soft value of the next state under the current policy.
            soft_state_values = (action_probabilities * (
                    torch.min(next_q_values_target, next_q_values_target2) - self.alpha * log_action_probabilities
            )).sum(dim=1)

            # No bootstrapping from next states of finished episodes.
            next_q_values = rewards_tensor + ~done_tensor * self.DISCOUNT_RATE*soft_state_values

        soft_q_values = self.critic_local(states_tensor).gather(1, actions_tensor.unsqueeze(-1)).squeeze(-1)
        soft_q_values2 = self.critic_local2(states_tensor).gather(1, actions_tensor.unsqueeze(-1)).squeeze(-1)
        critic_square_error = torch.nn.MSELoss(reduction="none")(soft_q_values, next_q_values)
        critic2_square_error = torch.nn.MSELoss(reduction="none")(soft_q_values2, next_q_values)
        weight_update = torch.min(critic_square_error, critic2_square_error).detach().tolist()
        self.replay_buffer.update_weights(weight_update)
        critic_loss = critic_square_error.mean()
        critic2_loss = critic2_square_error.mean()
        return critic_loss, critic2_loss

    def actor_loss(self, states_tensor):
        """Return ``(policy_loss, log_action_probabilities)`` for the given states."""
        action_probabilities, log_action_probabilities = self.get_action_info(states_tensor)
        q_values_local = self.critic_local(states_tensor)
        q_values_local2 = self.critic_local2(states_tensor)
        inside_term = self.alpha * log_action_probabilities - torch.min(q_values_local, q_values_local2)
        policy_loss = (action_probabilities * inside_term).sum(dim=1).mean()
        return policy_loss, log_action_probabilities

    def temperature_loss(self, log_action_probabilities):
        """Loss for ``log_alpha``; the policy term is detached so only the
        temperature receives gradients."""
        alpha_loss = -(self.log_alpha * (log_action_probabilities + self.target_entropy).detach()).mean()
        return alpha_loss

    def get_action_info(self, states_tensor):
        """Return the actor's action probabilities and their logarithms."""
        action_probabilities = self.actor_local.forward(states_tensor)
        # Add 1e-8 only where a probability is exactly zero, to avoid log(0).
        z = action_probabilities == 0.0
        z = z.float() * 1e-8
        log_action_probabilities = torch.log(action_probabilities + z)
        return action_probabilities, log_action_probabilities

    def get_action_probabilities(self, state):
        """Return the actor's action probabilities for one state as a NumPy array."""
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        action_probabilities = self.actor_local.forward(state_tensor)
        return action_probabilities.squeeze(0).detach().numpy()

    def soft_update_target_networks(self, tau=SOFT_UPDATE_INTERPOLATION_FACTOR):
        """Move both target critics towards their local critics by factor ``tau``."""
        self.soft_update(self.critic_target, self.critic_local, tau)
        self.soft_update(self.critic_target2, self.critic_local2, tau)

    def soft_update(self, target_model, origin_model, tau):
        """In place: ``target = tau * origin + (1 - tau) * target`` for every parameter."""
        for target_param, local_param in zip(target_model.parameters(), origin_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1 - tau) * target_param.data)

    def predict_q_values(self, state):
        """Return the element-wise minimum of the two local critics' outputs."""
        q_values = self.critic_local(state)
        q_values2 = self.critic_local2(state)
        return torch.min(q_values, q_values2)

In [5]:
class Env():
    """Wrapper around a Unity ML-Agents environment that drives a single agent
    of the first registered behavior with discrete actions."""

    def __init__(self, config):
        """Launch and reset the Unity environment.

        Args:
            config: dict with keys ``'unity_environment'`` (path given to
                ``UnityEnvironment`` as ``file_name``) and ``'time_scale'``.
        """
        self.observation_space = spaces.Tuple((spaces.Discrete(10), spaces.Discrete(1), spaces.Discrete(10))) 
        self.action_space = spaces.Discrete(5) 
        self.engine_channel = EngineConfigurationChannel()
        self.debug_channel = DebugSideChannel()
        self.env = UnityEnvironment(file_name=config['unity_environment'], 
                                    side_channels=[self.engine_channel, self.debug_channel])
        self.env.reset()
        self.engine_channel.set_configuration_parameters(time_scale=config['time_scale'])
        # Only the first behavior reported by the environment is used.
        self.behavior_registry = []
        self.behavior_registry.append(list(self.env.behavior_specs.keys())[0])

    def get_state(self):
        """Return the first observation of the first agent awaiting a decision."""
        # need to figure out negative reward later
        behavior_name = self.behavior_registry[0]
        decision_steps, terminal_steps = self.env.get_steps(behavior_name)
        state = decision_steps.obs[0][0]
        return state

    def step(self, action):
        """Apply one discrete action and advance the simulation by one step.

        Args:
            action: a single discrete action index (Python int or NumPy integer).

        Returns:
            ``(reward, next_state, done)``. ``reward`` is 1. when the terminal
            reward is positive and 0. otherwise; ``done`` is True when the step
            produced a terminal step.
        """
        behavior_name = self.behavior_registry[0]
        action_tuple = ActionTuple()
        # np.asarray lets plain Python ints through as well as NumPy scalars.
        action_tuple.add_discrete(np.asarray(action).reshape(1, 1))
        self.env.set_actions(behavior_name, action_tuple)
        self.env.step()

        decision_steps, terminal_steps = self.env.get_steps(behavior_name)
        done = len(terminal_steps) > 0
        won = len(terminal_steps.reward) > 0 and terminal_steps.reward[0] > 0

        reward = 0.
        if won:
            print('win')
            reward = 1.
            # On a win the final state is read from the debug side channel.
            next_state = self.debug_channel.get_last_state()
        elif len(decision_steps) > 0:
            next_state = decision_steps.obs[0][0]
        else:
            # Episode ended and no agent is awaiting a decision: fall back to
            # the terminal observation instead of indexing an empty batch.
            next_state = terminal_steps.obs[0][0]
        return reward, next_state, done

    def close(self):
        """Shut down the underlying Unity environment."""
        self.env.close()

In [8]:
# Hand-written debugging minibatch that the next cell feeds to `agent.test_networks`.
# Every entry is a (state, action, reward, next_state, done) tuple, the order in which
# `SACAgent.test_networks` unpacks its columns.
# NOTE(review): despite the singular name this is a list of many transitions.
transition = [
 # 100 non-terminal transitions, all with reward 0.
 ([4.5       , 0.5       , 0.5       ], 4, 0., [4.5       , 0.5       , 0.5       ], False),
 ([5.5       , 0.5       , 0.5       ], 3, 0., [5.5       , 0.5       , 0.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 3, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.49999994, 1.5       ], 3, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 1, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 2, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 0, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.49999985, 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.5       , 0.5       ], 3, 0., [5.5       , 0.5       , 0.5       ], False),
 ([4.5       , 0.5       , 0.5       ], 2, 0., [4.5       , 0.5       , 0.5       ], False),
 ([4.5       , 0.49999994, 1.5       ], 3, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 1, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.49999994, 1.5       ], 3, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 2, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 3, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 0, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 3, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 3, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([5.5       , 0.5       , 0.5       ], 3, 0., [5.5       , 0.5       , 0.5       ], False),
 ([4.5       , 0.49999985, 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 0, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 1, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 2, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 4, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 1, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 2, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.5       , 0.5       ], 3, 0., [5.5       , 0.5       , 0.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.49999985, 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.5       ], 4, 0., [4.5       , 0.49999994, 1.5       ], False),
 ([4.5       , 0.5       , 0.5       ], 2, 0., [4.5       , 0.5       , 0.5       ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 1, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.5       ], 3, 0., [4.5       , 0.49999994, 1.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 0, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.5       ], 2, 0., [4.5       , 0.49999994, 1.5       ], False),
 ([4.5       , 0.49999994, 1.5       ], 3, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 4, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 1, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 1, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 0, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 3, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.5       , 1.5       ], 2, 0., [5.5       , 0.5       , 1.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 1, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([4.5       , 0.49999994, 1.5       ], 3, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([4.5       , 0.4999999 , 1.5       ], 0, 0., [4.5       , 0.49999994, 1.5       ], False),
 ([5.5       , 0.5       , 0.5       ], 3, 0., [5.5       , 0.5       , 0.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 0, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.5       , 1.5       ], 1, 0., [5.5       , 0.5       , 1.5       ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 4, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 3, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.5       ], 3, 0., [4.5       , 0.49999994, 1.5       ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.5       , 0.5       ], 2, 0., [5.5       , 0.5       , 0.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.49999985, 1.4999998 ], 0, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 1, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 2, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 2, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.5       , 1.5       ], 2, 0., [5.5       , 0.5       , 1.5       ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.49999985, 1.4999998 ], 3, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.49999994, 1.5       ], 3, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 1, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 1, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.5       , 0.5       ], 3, 0., [5.5       , 0.5       , 1.5       ], False),
 ([5.5       , 0.5       , 1.5       ], 0, 0., [5.5       , 0.5       , 1.5       ], False),
 ([4.5       , 0.5       , 0.5       ], 2, 0., [4.5       , 0.5       , 0.5       ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 1, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 0, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.5       ], 2, 0., [4.5       , 0.49999994, 1.5       ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 1, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.5       ], 4, 0., [4.5       , 0.5       , 0.5       ], False),
 ([4.5       , 0.49999994, 1.5       ], 3, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([4.5       , 0.49999994, 1.5       ], 0, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 0, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.5       ], 1, 0., [4.5       , 0.49999994, 1.5       ], False),
 ([5.5       , 0.5       , 1.5       ], 0, 0., [5.5       , 0.5       , 1.5       ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 3, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.49999985, 1.4999998 ], 2, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 4, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([4.5       , 0.49999994, 1.5       ], 0, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([5.5       , 0.5       , 1.5       ], 2, 0., [4.5       , 0.49999994, 1.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 0, 0., [4.5       , 0.49999985, 1.4999998 ], False),
 ([4.5       , 0.49999994, 1.5       ], 0, 0., [4.5       , 0.4999999 , 1.5       ], False),
 ([4.5       , 0.4999999 , 1.4999998 ], 1, 0., [4.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.5       , 1.5       ], 1, 0., [5.5       , 0.5       , 1.5       ], False),
 ([5.5       , 0.4999999 , 1.4999998 ], 2, 0., [5.5       , 0.4999999 , 1.4999998 ], False),
 ([6.5       , 0.4999999 , 1.4999998 ], 0, 0., [6.5       , 0.4999999 , 1.4999998 ], False),
 ([5.5       , 0.5       , 0.5       ], 3, 0., [5.5       , 0.5       , 0.5       ], False),
 
 # 8 identical terminal transitions with reward 1 (action 3 taken from state [6.5, 0.5, 1.5]).
 ([6.5       , 0.5       , 1.5       ], 3, 1., [6.5       , 0.5       , 2.5       ], True),
 ([6.5       , 0.5       , 1.5       ], 3, 1., [6.5       , 0.5       , 2.5       ], True),
 ([6.5       , 0.5       , 1.5       ], 3, 1., [6.5       , 0.5       , 2.5       ], True),
 ([6.5       , 0.5       , 1.5       ], 3, 1., [6.5       , 0.5       , 2.5       ], True),
 ([6.5       , 0.5       , 1.5       ], 3, 1., [6.5       , 0.5       , 2.5       ], True),
 ([6.5       , 0.5       , 1.5       ], 3, 1., [6.5       , 0.5       , 2.5       ], True),
 ([6.5       , 0.5       , 1.5       ], 3, 1., [6.5       , 0.5       , 2.5       ], True),
 ([6.5       , 0.5       , 1.5       ], 3, 1., [6.5       , 0.5       , 2.5       ], True),
 
 ]

In [9]:
# Repeatedly fit the agent to the fixed debugging minibatch defined above.
# The loop is bounded so the cell finishes on its own under "Restart & Run All"
# instead of spinning until it is interrupted by hand; raise the count if the
# critics have not converged yet.
TEST_NETWORK_UPDATES = 2000
for _ in range(TEST_NETWORK_UPDATES):
    agent.test_networks(transition)

KeyboardInterrupt: 

In [17]:
# Show both local critics' Q-value estimates for the state that precedes the
# rewarded transitions in the debugging minibatch.
probe_state = [6.5, 0.5, 1.5]
for critic in (agent.critic_local, agent.critic_local2):
    print(critic(torch.tensor(probe_state)))

tensor([-5.8553, -4.0970,  1.4242, 74.7129,  0.9605], grad_fn=<AddBackward0>)
tensor([ 3.9447,  5.5563,  6.1680, 44.5385,  4.8056], grad_fn=<AddBackward0>)


In [6]:
# Train a fresh SAC agent for RUNS independent runs against the Unity environment.
driver_config = {'unity_environment': 'C:\\Users\\rmarr\\Documents\\ml-agents-dodgeball-env-ICT',
                'time_scale': 1.0}
environment = Env(driver_config)
agent_results = []  # one list of evaluation-episode rewards per run
try:
    for run in range(RUNS):
        agent = SACAgent(environment)
        run_results = []
        run_reward = 0
        for episode_number in range(EPISODES_PER_RUN):
            #print('\r', f'Run: {run + 1}/{RUNS} | Episode: {episode_number + 1}/{EPISODES_PER_RUN}', end=' ')
            # Every TRAINING_EVALUATION_RATIO-th episode acts greedily and does not train.
            evaluation_episode = episode_number % TRAINING_EVALUATION_RATIO == 0

            environment.env.reset()
            state = environment.get_state()

            done = False
            episode_reward = 0
            i = 0
            while not done and i < STEPS_PER_EPISODE:
                i += 1
                action = agent.get_next_action(state, evaluation_episode=evaluation_episode)
                reward, next_state, done = environment.step(action)
                # Env.step returns 1. only on a win, so these sums count wins.
                episode_reward += reward
                run_reward += reward
                if not evaluation_episode:
                    agent.train_on_transition(state, action, next_state, reward, done)
                state = next_state

            if evaluation_episode:
                run_results.append(episode_reward)

        agent_results.append(run_results)
        if WANDB:
            wandb.log({"runs wins": run_reward})
finally:
    # Always release the Unity process, including on KeyboardInterrupt, so the
    # cell can be re-run without a stale environment still holding the port.
    environment.env.close()

KeyboardInterrupt: 