In [1]:
import mlagents
from mlagents_envs.environment import UnityEnvironment as UE
from mlagents_envs.envs.unity_parallel_env import UnityParallelEnv as UPZBE
from SAC_Distillation.DistilledSACAgent import DistilledSAC
from SAC_Distillation.Trajectories import SAC_ExperienceBuffer
from Hyperparameters import HYPERPARAMS as params
import numpy as np
import torch
import wandb

In [2]:
# Open the Weights & Biases run and record the SAC-distillation hyperparameters in its config.
wandb.init(entity="fede-", project="SAC_Distillation")
wandb.config.update(params['sac_distilled'])

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU;
# the selected device is also written to the run config.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
wandb.config.update({"device": device})

[34m[1mwandb[0m: Currently logged in as: [33mrullofederico16[0m ([33mfede-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [3]:
def relocate_agents(env):
    """Return a new list containing every entry of ``env.agents``, in iteration order.

    Args:
        env: Any object exposing an iterable ``agents`` attribute.

    Returns:
        list: A fresh list, so later changes to ``env.agents`` do not affect it.
    """
    return [agent for agent in env.agents]

# New helper to extract observation data for an agent
def get_agent_obs(obs, agent):
    """Return two components of one agent's observation as NumPy arrays.

    Args:
        obs: Mapping from agent identifier to that agent's observation entry.
        agent: Key selecting the entry to read.

    Returns:
        tuple: ``(np.array(entry[1]), np.array(entry[2]))`` where ``entry`` is
        ``obs[agent]``. Index 0 of the entry is not used.
    """
    entry = obs[agent]
    obs_a = np.array(entry[1])
    obs_b = np.array(entry[2])
    return obs_a, obs_b

In [4]:
# Start the "DroneFlightv1" Unity build with graphics disabled and a fixed seed,
# then wrap it in the UnityParallelEnv adapter used by the rest of the notebook.
env = UE(
    file_name="DroneFlightv1",
    seed=1,
    side_channels=[],
    no_graphics_monitor=True,
    no_graphics=True,
)
env = UPZBE(env)

In [5]:
# Snapshot the agent identifiers currently exposed by the environment and show them.
agents = relocate_agents(env)
print(agents)

['Drone?team=0?agent_id=0', 'Drone?team=0?agent_id=1', 'Drone?team=0?agent_id=10', 'Drone?team=0?agent_id=11', 'Drone?team=0?agent_id=2', 'Drone?team=0?agent_id=3', 'Drone?team=0?agent_id=4', 'Drone?team=0?agent_id=5', 'Drone?team=0?agent_id=6', 'Drone?team=0?agent_id=7', 'Drone?team=0?agent_id=8', 'Drone?team=0?agent_id=9']


In [6]:
# Reset the environment and inspect what the first agent receives.
obs = env.reset()
print(obs[agents[0]][1])  # element 1 of the first agent's observation entry
# Draw one random sample from the first agent's action space
# (a single sampled action, despite the variable name).
possible_actions = env.action_space(agents[0]).sample()
print(f"Possible actions: {possible_actions}")
# Shapes that are later passed to the experience buffer and the agent constructors.
print(env.action_space(agents[0]).shape)
print(env.observation_space(agents[0])[1].shape)
print(env.observation_space(agents[0])[2].shape)

[[[0.8039215 0.8039215 0.8039215 ... 0.8039215 0.8039215 0.8039215]
  [0.8039215 0.8039215 0.8039215 ... 0.8039215 0.8039215 0.8039215]
  [0.8039215 0.8039215 0.8039215 ... 0.8039215 0.8039215 0.8039215]
  ...
  [0.8039215 0.8039215 0.8039215 ... 0.8039215 0.8039215 0.8039215]
  [0.8039215 0.8039215 0.8039215 ... 0.8039215 0.8039215 0.8039215]
  [0.8039215 0.8039215 0.8039215 ... 0.8039215 0.8039215 0.8039215]]]
Possible actions: [ 0  1  0 -1  0  0]
(6,)
(1, 84, 84)
(24,)


In [7]:
# Build the experience buffer from the first agent's spaces: the shapes of
# observation components 1 and 2, the action shape, and the SAC hyperparameters.
Buffer = SAC_ExperienceBuffer(
    env.observation_space(agents[0])[1].shape,
    env.observation_space(agents[0])[2].shape,
    env.action_space(agents[0]).shape,
    params['sac_distilled'],
)

In [8]:
# Build the distilled SAC agent from the first agent's spaces (observation
# components 1 and 2, action shape), the number of agents, and the hyperparameters.
brain = DistilledSAC(
    env.observation_space(agents[0])[1].shape,
    env.observation_space(agents[0])[2].shape,
    env.action_space(agents[0]).shape,
    len(agents),
    params['sac_distilled'],
)

In [9]:
# Random-exploration warm-up: step the environment with uniformly sampled
# actions and store the resulting transitions (plus the critic's value
# estimate for each sampled action) in the experience buffer.
for s in range(1, 2):
    obs, done, t = env.reset(), [False for _ in env.agents], 0
    # NOTE(review): with `or`, the loop only stops once every agent is done
    # AND at least 10 steps have run -- confirm this is the intended condition.
    while not all(done) or t < 10:
        actions = {}
        log_probs = {}
        values = {}
        # Per-agent current observation, so each transition is stored with the
        # observation that belongs to that agent (not the last agent visited).
        agent_obs = {}
        agents = relocate_agents(env)
        for agent in agents:
            # Every listed agent gets an action, even if it has no observation.
            actions[agent] = env.action_space(agent).sample()
            if agent not in obs:
                continue
            obs1, obs2 = get_agent_obs(obs, agent)
            agent_obs[agent] = (obs1, obs2)
            v1, v2 = brain.get_values(obs1, obs2, actions[agent])
            # Keep the smaller of the two critic estimates.
            values[agent] = torch.min(v1, v2)

        next_obs, reward, done, _ = env.step(actions)

        for agent in agents:
            # A transition needs both the pre-step and the post-step observation.
            if agent not in next_obs or agent not in agent_obs:
                continue
            obs1, obs2 = agent_obs[agent]
            next_obs1, next_obs2 = get_agent_obs(next_obs, agent)
            Buffer.store(obs1, obs2, actions[agent], reward[agent], next_obs1, next_obs2, done[agent], values[agent])
        obs = next_obs
        # Collapse the per-agent done mapping into a list for the loop condition.
        done = [done[agent] for agent in agents if agent in done]
        t += 1
    print(f'Finished episode {s}')

# Buffer.compute_advantages()
print("Finished Rnd Exploration")
env.close()

Finished episode 1
Finished Rnd Exploration


In [10]:
# The exploration cell closed its environment, so start a new "DroneFlightv1"
# instance with the same settings, wrap it, and refresh the agent list.
env = UE(
    file_name="DroneFlightv1",
    seed=1,
    side_channels=[],
    no_graphics_monitor=True,
    no_graphics=True,
)
env = UPZBE(env)
agents = relocate_agents(env)

In [11]:
# Run a training pass on the transitions gathered during random exploration.
# NOTE(review): the training loop further down calls `brain.train(steps, Buffer)`;
# confirm which argument list matches DistilledSAC.train.
brain.train(Buffer)

  critic_loss = F.mse_loss(q1, target_q)
  context_layer = torch.nn.functional.scaled_dot_product_attention(


In [None]:
# Main loop: collect one rollout per outer iteration, store its transitions in
# the buffer, run an optimisation step, adjust the learning rates, and log the
# reward to Weights & Biases until `max_steps` environment steps are reached.
sac_params = params['sac_distilled']
steps = 0
best_mean_reward = -np.inf
while steps < sac_params.max_steps:
    obs, done, t = env.reset(), [False for _ in env.agents], 0
    episode_reward = 0
    # NOTE(review): with `or`, the rollout only stops once every agent is done
    # AND at least `n_steps` steps have run -- confirm this is intended.
    while not all(done) or t < sac_params.n_steps:
        actions = {}
        log_probs = {}
        values = {}
        # Per-agent current observation, so each transition is stored with the
        # observation that belongs to that agent (not the last agent visited).
        agent_obs = {}
        agents = relocate_agents(env)
        for agent in agents:
            # NOTE(review): actions are sampled uniformly from the action space
            # here as well; the agent's policy is never queried -- confirm.
            actions[agent] = env.action_space(agent).sample()
            if agent not in obs:
                continue
            obs1, obs2 = get_agent_obs(obs, agent)
            agent_obs[agent] = (obs1, obs2)
            v1, v2 = brain.get_values(obs1, obs2, actions[agent])
            # Keep the smaller of the two critic estimates.
            values[agent] = torch.min(v1, v2)

        next_obs, reward, done, _ = env.step(actions)

        for agent in agents:
            # A transition needs both the pre-step and the post-step observation.
            if agent not in next_obs or agent not in agent_obs:
                continue
            obs1, obs2 = agent_obs[agent]
            next_obs1, next_obs2 = get_agent_obs(next_obs, agent)
            Buffer.store(obs1, obs2, actions[agent], reward[agent], next_obs1, next_obs2, done[agent], values[agent])
        obs = next_obs
        # Collapse the per-agent done mapping into a list for the loop condition.
        done = [done[agent] for agent in agents if agent in done]
        # Overwritten every step, so only the final step's rewards survive the loop.
        tot_reward = [reward[agent] for agent in agents if agent in reward]
        t += 1

    obs_keys = list(obs.keys())
    # Mean over agents of the rewards from the last step of the rollout.
    mean_reward = np.mean(tot_reward)

    steps += t

    # SAC optimization step
    # NOTE(review): the earlier warm-up cell calls `brain.train(Buffer)`;
    # confirm which argument list matches DistilledSAC.train.
    brain.train(steps, Buffer)

    Buffer.compute_advantages()
    brain.actor_optimizer = brain.adjust_lr(brain.actor_optimizer, sac_params.lr, steps, sac_params.n_steps)
    brain.critic_1_optimizer = brain.adjust_lr(brain.critic_1_optimizer, sac_params.lr, steps, sac_params.n_steps)
    brain.distill_optimizer = brain.adjust_lr(brain.distill_optimizer, sac_params.lr, steps, sac_params.n_steps)
    wandb.log({"Mean Reward": mean_reward, "Steps": steps})
env.close()

KeyboardInterrupt: 

In [None]:
# Manual cleanup: close the environment if the training cell above was
# interrupted before reaching its own `env.close()`.
env.close()