In [None]:
import mlagents
from mlagents_envs.environment import UnityEnvironment as UE
from mlagents_envs.envs.unity_parallel_env import UnityParallelEnv as UPZBE
from SAC_Distillation.DistilledSACAgent import DistilledSAC
from SAC_Distillation.Trajectories import ExperienceBuffer
from Hyperparameters import HYPERPARAMS as params
import numpy as np
import torch
import wandb

In [None]:
# Experiment tracking + compute device selection.
wandb.init(project="SAC_Distillation", entity="fede-")
wandb.config.update(params['sac_distilled'])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Log the device as a plain string: config values should be serialisable primitives,
# not torch objects.
wandb.config.update({"device": str(device)})

In [None]:
def relocate_agents(env):
    """Return a new list holding the agent IDs currently exposed by ``env.agents``."""
    return [agent_id for agent_id in env.agents]

# Helper: pull the two observation components fed to the networks for a single agent.
def get_agent_obs(obs, agent):
    """Return components 1 and 2 of ``obs[agent]`` converted to NumPy arrays.

    ``obs`` is the per-agent observation mapping returned by ``env.reset()`` /
    ``env.step()``; each entry is indexed positionally and component 0 is not used here.
    """
    first, second = (np.array(obs[agent][idx]) for idx in (1, 2))
    return first, second

In [None]:
# Launch the headless Unity build and wrap it in the PettingZoo parallel API.
env = UPZBE(
    UE(
        file_name="DroneFlightv1",
        seed=1,
        side_channels=[],
        no_graphics_monitor=True,
        no_graphics=True,
    )
)

In [None]:
# Snapshot the live agent IDs; the following cells read the spaces of `agents[0]`,
# so fail fast with a clear message instead of a later IndexError.
agents = relocate_agents(env)
if not agents:
    raise RuntimeError("Environment exposes no agents; cannot infer observation/action spaces.")
print(agents)

In [None]:
# Size the experience buffer from the first agent's spaces
# (assumes every agent shares the same spaces — TODO confirm).
_first_obs_space = env.observation_space(agents[0])
Buffer = ExperienceBuffer(
    _first_obs_space[1].shape,
    _first_obs_space[2].shape,
    env.action_space(agents[0]).shape,
    params['sac_distilled'],
)

In [None]:
# Build the agent from the first agent's spaces
# (assumes every agent shares the same spaces — TODO confirm).
_first_obs_space = env.observation_space(agents[0])
brain = DistilledSAC(
    _first_obs_space[1].shape,
    _first_obs_space[2].shape,
    env.action_space(agents[0]).shape,
    params['sac_distilled'],
    device,
)

In [None]:
# Seed phase: roll out the freshly constructed `brain` to pre-fill `Buffer` before training.
# NOTE(review): despite the "Rnd Exploration" label, actions come from `brain.get_action`,
# not from `env.action_space(agent).sample()`.
# NOTE(review): this cell reads `params['ppo_distilled']` while every other cell uses
# `params['sac_distilled']` — confirm the key is intentional.
explore_cfg = params['ppo_distilled']
for s in range(1, explore_cfg.seed_episodes + 1):
    obs, done, t = env.reset(), [False for _ in env.agents], 0
    # Keep stepping until every agent is done AND the agent-step budget has been reached.
    while not all(done) or t < explore_cfg.n_steps_random_exploration:
        actions = {}
        log_probs = {}
        values = {}
        prev_obs = {}  # agent id -> (obs1, obs2) the action was computed from
        agents = relocate_agents(env)
        for agent in agents:
            if agent not in obs.keys():
                continue
            obs1, obs2 = get_agent_obs(obs, agent)
            prev_obs[agent] = (obs1, obs2)
            actions[agent], log_probs[agent], values[agent] = brain.get_action(obs1, obs2)
            t += 1

        obs, reward, done, _ = env.step(actions)
        for agent, (obs1, obs2) in prev_obs.items():
            if agent not in reward or agent not in done:
                continue
            # Store the observation the action/log_prob/value were computed from (s_t),
            # not the post-step observation (s_{t+1}), so each transition stays aligned.
            Buffer.add(obs1, obs2, actions[agent], reward[agent], done[agent], log_prob=log_probs[agent], value=values[agent])
        done = [done[agent] for agent in agents if agent in done.keys()]
    print(f'Finished episode {s}')

Buffer.compute_advantages_and_returns()
print("Finished Rnd Exploration")
env.close()

In [None]:
# The exploration cell closed the previous environment; launch a fresh headless instance
# and wrap it in the PettingZoo parallel API again.
env = UPZBE(
    UE(
        file_name="DroneFlightv1",
        seed=1,
        side_channels=[],
        no_graphics_monitor=True,
        no_graphics=True,
    )
)

In [None]:
# Snapshot the live agent IDs of the new environment; the next cell reads the spaces of
# `agents[0]`, so fail fast with a clear message instead of a later IndexError.
agents = relocate_agents(env)
if not agents:
    raise RuntimeError("Environment exposes no agents; cannot infer observation/action spaces.")
print(agents)

In [None]:
# Rebuild the agent from the new environment's first-agent spaces; this replaces the
# instance used during the seed phase (assumes all agents share spaces — TODO confirm).
_first_obs_space = env.observation_space(agents[0])
brain = DistilledSAC(
    _first_obs_space[1].shape,
    _first_obs_space[2].shape,
    env.action_space(agents[0]).shape,
    params['sac_distilled'],
    device,
)

In [None]:
# Hand the buffer filled during the seed-phase rollouts to DistilledSAC.fine_tune_teacher.
# NOTE(review): `brain` was re-created just above, so this runs on a fresh instance rather
# than the one used to collect the buffer — confirm that is intended.
brain.fine_tune_teacher(Buffer)

In [None]:
# Main loop: collect a rollout with the current policy, train, anneal the learning rates,
# and log the mean per-agent return of the rollout.
sac_cfg = params['sac_distilled']
steps = 0
# NOTE(review): `best_mean_reward` / `not_improved` are initialised but never updated or
# read here — presumably meant for checkpointing / early stopping.
best_mean_reward = -np.inf
not_improved = 0
while steps < sac_cfg.max_steps:
    obs, done, t = env.reset(), [False for _ in env.agents], 0
    agent_returns = {}  # agent id -> reward accumulated over this rollout
    # Keep stepping until every agent is done AND at least `n_steps` agent-steps were taken.
    while not all(done) or t < sac_cfg.n_steps:
        actions = {}
        log_probs = {}
        values = {}
        prev_obs = {}  # agent id -> (obs1, obs2) the action was computed from
        agents = relocate_agents(env)
        for agent in agents:
            if agent not in obs.keys():
                continue
            obs1, obs2 = get_agent_obs(obs, agent)
            prev_obs[agent] = (obs1, obs2)
            actions[agent], log_probs[agent], values[agent] = brain.get_action(obs1, obs2)
            t += 1

        obs, reward, done, _ = env.step(actions)
        for agent, (obs1, obs2) in prev_obs.items():
            if agent not in reward or agent not in done:
                continue
            # Store the observation the action/log_prob/value were computed from (s_t),
            # not the post-step observation (s_{t+1}); the final s_T is added separately below.
            Buffer.add(obs1, obs2, actions[agent], reward[agent], done[agent], log_prob=log_probs[agent], value=values[agent])
            agent_returns[agent] = agent_returns.get(agent, 0.0) + reward[agent]
        done = [done[agent] for agent in agents if agent in done.keys()]

    # Bootstrap from the last observation of the rollout.
    # NOTE(review): only the last agent in `obs` is bootstrapped — confirm for multi-agent runs.
    last_agent = list(obs.keys())[-1]
    final_obs1, final_obs2 = get_agent_obs(obs, last_agent)
    _, _, last_values = brain.get_action(final_obs1, final_obs2)
    Buffer.add_final_state(final_obs1, final_obs2, last_values)
    # Mean over agents of the reward each accumulated during the whole rollout
    # (previously only the final step's rewards were averaged).
    mean_reward = float(np.mean(list(agent_returns.values()))) if agent_returns else float("nan")

    steps += t

    # SAC optimization step
    brain.train(steps, Buffer)

    # NOTE(review): the horizon passed to `improv_lr` is `n_steps` (rollout length); if the
    # schedule should span the whole run this may need to be `max_steps` — confirm.
    brain.optimizer = brain.improv_lr(brain.optimizer, sac_cfg.lr, steps, sac_cfg.n_steps)
    brain.optimizer_distill = brain.improv_lr(brain.optimizer_distill, sac_cfg.lr, steps, sac_cfg.n_steps)
    wandb.log({"Mean Reward": mean_reward, "Steps": steps})
env.close()

In [None]:
# Manual cleanup cell: closes the environment, e.g. when the training cell above was
# interrupted before reaching its own `env.close()`.
# NOTE(review): after a normal run this is a second close — relies on the wrapper's
# close() tolerating that; confirm.
env.close()