In [1]:
import mlagents
from mlagents_envs.environment import UnityEnvironment as UE
from mlagents_envs.envs.unity_parallel_env import UnityParallelEnv as UPZBE
from SAC_Distillation.DistilledSACAgent import DistilledSAC
from SAC_Distillation.Trajectories import SAC_ExperienceBuffer
from Hyperparameters import HYPERPARAMS as params
import numpy as np
import torch
import wandb

# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
import os
# Debugging aid: make CUDA kernel launches synchronous so errors surface at the offending call.
# NOTE(review): CUDA reads this variable when it initialises; cell 1 already imported torch and
# called torch.cuda.is_available(), so confirm it still takes effect (safest: set it before importing torch).
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

In [3]:
def relocate_agents(env):
    """Return the de-duplicated ids currently listed in ``env.agents``.

    The result is sorted so the agent -> slot mapping used when filling the joint
    (N_AGENTS, ...) arrays is deterministic. Iterating a ``set`` of strings can
    yield a different order in every interpreter run, because string hashing is
    randomised (PYTHONHASHSEED).
    """
    return sorted(set(env.agents))

def get_agent_obs(obs, agent, *, cam_key=2, vec_key=1):
    """Return the ``(camera, vector)`` observation pair for one agent.

    Args:
        obs: mapping from agent id to that agent's observations. Each entry is
            either an indexable sequence of per-sensor observations or a dict
            whose ``"observation"`` key holds that sequence (presumably the
            form the wrapper uses when extra data such as an action mask is
            attached — confirm).
        agent: key of the agent to extract.
        cam_key: index of the camera observation within the sequence.
        vec_key: index of the vector observation within the sequence.

    Returns:
        ``(cam, vec)`` as numpy arrays; ``cam`` is 3-D and ``vec`` is 1-D or 2-D.

    Raises:
        KeyError: if ``agent`` is not present in ``obs``.
        ValueError: if the extracted arrays do not have the expected rank.
    """
    data = obs[agent]
    if isinstance(data, dict) and "observation" in data:
        data = data["observation"]
    cam = np.asarray(data[cam_key])
    vec = np.asarray(data[vec_key])

    # Explicit checks instead of `assert`, which is removed under `python -O`.
    if cam.ndim != 3:
        raise ValueError(f"Camera observation should be 3D array, got shape {cam.shape}")
    if vec.ndim not in (1, 2):
        raise ValueError(f"Vector observation should be 1D or 2D array, got shape {vec.shape}")
    return cam, vec

In [4]:
# Launch the Unity build (fixed seed) and expose it through the PettingZoo parallel API.
env = UPZBE(UE(file_name="Env/DroneFlightv1", seed=1))

In [5]:
# Sorted so the agent -> slot order does not depend on string-hash randomisation between runs.
agents = sorted(set(env.agents))

In [6]:
obs = env.reset()
# Show the shape of each of the first agent's four observations, then its action-space shape.
first_agent_space = env.observation_space(agents[0])
for obs_index in range(4):
    print(first_agent_space[obs_index].shape)
print(env.action_space(agents[0]).shape)

(3, 84, 84)
(68,)
(4, 84, 84)
(64,)
(5,)


In [7]:
# Shapes are read from the first agent only; every agent is assumed to share them.
_first_obs_space = env.observation_space(agents[0])
cam_shape = _first_obs_space[2].shape             # observation index 2: (4, 84, 84) in the recorded run
vec_shape = _first_obs_space[1].shape             # observation index 1: (68,) in the recorded run
action_shape = env.action_space(agents[0]).shape  # (5,) in the recorded run

In [8]:
# Joint-transition replay buffer sized from the observation/action shapes above.
sac_cfg = params['sac_distilled']
replay_buffer = SAC_ExperienceBuffer(cam_shape, vec_shape, action_shape, sac_cfg)

In [9]:
FEATURE_EXTRACTOR_CKPT = "SavedModels/feature_extractor_contrastive_init.pth"

agent = DistilledSAC(cam_shape, vec_shape, action_shape, len(agents), params['sac_distilled'])
# Initialise the convolutional encoder from the contrastive pre-training checkpoint.
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one (and vice versa).
agent.model.convolution_pipeline.load_state_dict(
    torch.load(FEATURE_EXTRACTOR_CKPT, map_location=device)
)
agent.model.convolution_pipeline.to(device)

Number of agents:  4
Action dimensions:  5


FeatureExtractionNet(
  (convolutional_pipeline): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (distilled_converter): Linear(in_features=6400, out_features=12800, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [10]:
N_AGENTS = len(agents)
cfg = params['sac_distilled']
RAND_STEPS = cfg.get("n_steps_random_exploration", 10_000)  # random-exploration step budget per seed episode
SEED_EPISODES = cfg.get("seed_episodes", 5)
# Zero placeholders for agents that have no observation at a given step.
# float32 rather than uint8: mlagents_envs delivers camera observations as float32
# (pixel values scaled to [0, 1]), and writing those into a uint8 array truncates
# every value below 1.0 to 0.
blank_cam = np.zeros(cam_shape, dtype=np.float32)
blank_vec = np.zeros(vec_shape, dtype=np.float32)

In [None]:
# Seed the replay buffer with random actions: SEED_EPISODES episodes, each capped
# at RAND_STEPS environment steps.
tot_steps = RAND_STEPS * SEED_EPISODES  # upper bound on the number of stored transitions
n_stored = 0

for _episode in range(SEED_EPISODES):
    # Reset once per episode. Resetting before *every* step while storing the
    # pre-reset observation as "current" made each stored (obs, next_obs) pair
    # straddle a reset.
    obs_dict = env.reset()

    for _step in range(RAND_STEPS):
        if not obs_dict:
            break  # nothing to act on; start a fresh episode
        # Refresh the roster each step, as the training loop does, since
        # env.agents may change between steps.
        step_agents = relocate_agents(env)
        # NOTE(review): these are raw action-space samples, whereas the training
        # loop stores actions rounded/clamped to {-1, 0, 1}; confirm mixing both
        # in one buffer is intended.
        act_dict = {a: env.action_space(a).sample() for a in step_agents}

        # float32 keeps float camera observations intact (a uint8 array would
        # truncate them to 0). Rows for absent agents stay zero.
        cam_now = np.zeros((N_AGENTS, *cam_shape), dtype=np.float32)
        vect_now = np.zeros((N_AGENTS, *vec_shape), dtype=np.float32)
        act_now = np.zeros((N_AGENTS, *action_shape), dtype=np.float32)
        for i, a in enumerate(step_agents):
            if a in obs_dict:
                cam_now[i], vect_now[i] = get_agent_obs(obs_dict, a)
            act_now[i] = act_dict[a]

        next_obs, rew_dict, done_dict, _ = env.step(act_dict)

        cam_next = np.zeros_like(cam_now)
        vect_next = np.zeros_like(vect_now)
        rew_now = np.zeros((N_AGENTS, 1), dtype=np.float32)
        done_now = np.zeros((N_AGENTS, 1), dtype=np.float32)
        for i, a in enumerate(step_agents):
            if a in next_obs:
                cam_next[i], vect_next[i] = get_agent_obs(next_obs, a)
            rew_now[i, 0] = rew_dict.get(a, 0.0)
            done_now[i, 0] = done_dict.get(a, False)

        replay_buffer.store_joint(
            cam_now, vect_now, act_now,
            rew_now,
            cam_next, vect_next,
            done_now,
        )
        n_stored += 1
        obs_dict = next_obs

        # End the episode once every agent acted on this step reports done.
        if all(done_dict.get(a, False) for a in step_agents):
            break

print(f"Finished collecting random steps ({n_stored} transitions)")



In [None]:
# replay_buffer.save('Collected_Experience/init_experience.npz')

In [None]:
# replay_buffer.load('Collected_Experience/init_experience.npz')

In [None]:
# agent.fine_tune_teacher(replay_buffer, epochs=cfg.get('train_epochs', 10))

In [None]:
# agent.teacher.save("SavedModels/Teacher")

In [None]:
# from pathlib import Path


# epochs = cfg.get('train_epochs', 10)
# batch_size = 128
# distill_lr = cfg.get('distill_lr', 1e-4)
# temperature = cfg.get('temperature', 0.07)
# num_frames = cfg.get('num_frames', 4_000)

# assert len(replay_buffer) >= batch_size, "Not enough data in the buffer to start training"
# print(f"Starting distillation training for {epochs} epochs with batch size {batch_size}")
# agent.offline_distill(
#     frame_buffer=replay_buffer,
#     epochs=epochs,
#     batch_size=batch_size,
#     lr=distill_lr,
#     temperature=temperature,
#     num_frames=num_frames,
# )
# out_dir = Path("SavedModels")
# out_dir.mkdir(exist_ok=True)


In [None]:
# torch.save(agent.model.convolution_pipeline.state_dict(), out_dir / "student_latest.pth")

In [None]:
# agent.model.convolution_pipeline.load_state_dict(torch.load("SavedModels/student_latest.pth"))

In [None]:
# agent.train(Buffer,step = params['sac_distilled'].seed_episodes*params['sac_distilled'].n_steps_random_exploration)

In [None]:
import datetime as dt
# Timestamped run name so repeated launches do not collide in the wandb project.
run_name = "sac_distill_" + dt.datetime.now().strftime("%Y%m%d_%H%M%S")
# Log the full SAC config plus the device the run used.
run_config = dict(params["sac_distilled"], device=str(device))
wandb.init(
    project=os.getenv("WANDB_PROJECT", "SAC_Distillation"),
    entity=os.getenv("WANDB_ENTITY", "fede-"),
    name=run_name,
    config=run_config,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrullofederico16[0m ([33mfede-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Schedule, measured in environment steps.
max_steps = cfg.get("max_steps", 5_000_000)
train_every = 4_096
log_every = 4_096     # NOTE(review): not referenced by the training loop as written
print_every = 10_000  # NOTE(review): not referenced by the training loop as written
total_updates = 0

# Exponential moving average of the per-step mean reward; the best value seen so
# far decides when a checkpoint is written.
ema_alpha = cfg.get("ema_alpha", 0.01)
ema_reward = 0.0
last_ema_reward = -np.inf

steps = 0
obs = env.reset()

# Snapshot of the first centralised-critic head's mean |weight|, kept for
# checking whether training actually changes those weights.
with torch.no_grad():
    cc_weights = agent.ccritic.q_heads[0].weight.abs().mean()

In [None]:
# Main loop: act with the current policy, store joint transitions, and run one
# agent.train() call every `train_every` environment steps.
BEST_MODEL_PATH = "SavedModels/SAC_distilled_trained.pth"
GOAL_REWARD_THRESHOLD = 19    # a per-agent step reward above this counts as a goal
CRASH_REWARD_THRESHOLD = -9   # a per-agent step reward below this counts as a crash
CRITIC_LOSS_LIMIT = 1e6       # critic loss above this (or non-finite) triggers a reload

# Goal/crash counts accumulated since the last wandb log. Previously these were
# reset every step, so the logged value only covered the single logging step.
goal_reached = 0.0
crashed = 0.0

while steps < max_steps:
    if not obs:
        obs = env.reset()
        continue
    agents = relocate_agents(env)

    # float32 so float camera observations are not truncated (uint8 turned them
    # into 0). Rows for agents without an observation stay zero.
    cam_now = np.zeros((N_AGENTS, *cam_shape), dtype=np.float32)
    vect_now = np.zeros((N_AGENTS, *vec_shape), dtype=np.float32)
    for i, aid in enumerate(agents):
        if aid in obs:
            cam_now[i], vect_now[i] = get_agent_obs(obs, aid)

    # The arrays are already (N_AGENTS, ...); no extra np.stack is needed.
    cam_t = torch.from_numpy(cam_now).float().to(device)
    vec_t = torch.from_numpy(vect_now).float().to(device)
    if cam_t.isnan().any() or vec_t.isnan().any():
        cam_t = torch.nan_to_num(cam_t)
        vec_t = torch.nan_to_num(vec_t)

    with torch.no_grad():
        act_t = agent.act(cam_t, vec_t, step=steps)

    # Actions are discretised to {-1, 0, 1} before being sent to Unity and stored.
    act_np = torch.round(act_t).clamp(-1, 1).cpu().numpy()
    actions = {aid: action for aid, action in zip(agents, act_np)}

    next_obs, rew_dict, done_dict, infos = env.step(actions)
    for aid, r in rew_dict.items():
        if abs(r) > 1.0:
            print(f"agent={aid} raw_reward={r} done={done_dict.get(aid)}")
    steps += 1

    cam_next = np.zeros_like(cam_now)
    vect_next = np.zeros_like(vect_now)
    rew_now = np.zeros((N_AGENTS, 1), dtype=np.float32)
    done_now = np.zeros((N_AGENTS, 1), dtype=np.float32)

    for i, aid in enumerate(agents):
        if aid in next_obs:
            cam_next[i], vect_next[i] = get_agent_obs(next_obs, aid)

        # Individual reward plus any shared group reward reported in infos.
        r = rew_dict.get(aid, 0.0) + infos.get(aid, {}).get("group_reward", 0.0)
        rew_now[i, 0] = r
        done_now[i, 0] = done_dict.get(aid, False)

        goal_reached += 1 if r > GOAL_REWARD_THRESHOLD else 0
        crashed += 1 if r < CRASH_REWARD_THRESHOLD else 0

    replay_buffer.store_joint(
        cam_now, vect_now, act_np,
        rew_now,
        cam_next, vect_next,
        done_now,
    )

    mean_r = np.mean(rew_now)
    ema_reward = ema_reward * (1 - ema_alpha) + mean_r * ema_alpha

    if steps % train_every == 0:
        a_loss, c_loss = agent.train(replay_buffer, step=steps)
        total_updates += 1
        c_loss_val = float(c_loss)

        # `nan > limit` is False, so non-finite losses need an explicit check.
        if not np.isfinite(c_loss_val) or c_loss_val > CRITIC_LOSS_LIMIT:
            if os.path.exists(BEST_MODEL_PATH):
                agent.load(BEST_MODEL_PATH)
                print("Critic loss exploded, reloading model")
            else:
                print("Critic loss exploded, but no checkpoint exists yet to reload")
        # elif: never checkpoint weights from an update whose critic just diverged.
        elif ema_reward > last_ema_reward:
            last_ema_reward = ema_reward
            agent.save(BEST_MODEL_PATH)
            print(f"New best EMA reward: {last_ema_reward:.2f}")

        wandb.log({
            "EMA Reward": ema_reward,
            "Mean Reward": mean_r,
            "Actor Loss": a_loss,
            "Critic Loss": c_loss,
            "Steps": steps,
            "Goal Reached": goal_reached,
            "Crashes": crashed,
        }, step=steps)
        goal_reached = 0.0
        crashed = 0.0

    obs = next_obs

env.close()

agent=Drone?team=0?agent_id=0 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=1 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=2 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=3 raw_reward=0.06800000369548798 done=False
agent=Drone?team=0?agent_id=0 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=1 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=2 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=3 raw_reward=0.06800000369548798 done=False
agent=Drone?team=0?agent_id=0 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=1 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=2 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=3 raw_reward=0.06800000369548798 done=False
agent=Drone?team=0?agent_id=0 raw_reward=0.04800000041723251 done=False
agent=Drone?team=0?agent_id=1 raw_reward=0.04800000041723251 don

KeyboardInterrupt: 

In [None]:
# while steps < max_steps:
#     if len(obs) == 0:
#         obs = env.reset()   # poll again without acting
#         continue
#     active_agents = relocate_agents(env)
#     cams, vecs = [], []
#     for aid in active_agents:
#         cam, vec = get_agent_obs(obs, aid)
#         cams.append(cam)
#         vecs.append(vec)
#     if not cams:
#         obs = env.reset()
#         continue
#     cam_t = torch.as_tensor(np.stack(cams), device=device, dtype=torch.float32).unsqueeze(0)
#     vec_t = torch.as_tensor(np.stack(vecs), device=device, dtype=torch.float32).unsqueeze(0)

#     with torch.no_grad():
#         step_fraction = steps / max_steps
#         act_t = agent.get_action(cam_t, vec_t, train=False, add_noise=True, step_fraction=step_fraction)

#     act_np = torch.round(act_t).clamp(-1,1)
#     act_np = act_np[0].cpu().numpy()
#     actions = {a: act for a, act in zip(active_agents, act_np)}

#     next_obs, rewards, done_flags, infos = env.step(actions)
#     steps += 1

#     cam2, vec2, r_list, d_list = [], [], [], []
#     goal_reached, crashes = 0.0, 0.0
#     for aid in active_agents:
#         if aid in next_obs:
#             n_cam, n_vec = get_agent_obs(next_obs, aid)
#         else:
#             n_cam = np.zeros_like(cams[0])
#             n_vec = np.zeros_like(vecs[0])
#         cam2.append(n_cam)
#         vec2.append(n_vec)

#         r = rewards.get(aid, 0.0) + infos.get(aid, {}).get('group_reward', 0.0)
        
#         goal_reached += 1 if r > 19.9 else 0.0
#         crashes += 1 if r < -9.9 else 0.0
        
#         r_list.append(r)
#         d_list.append(done_flags.get(aid, False))

#     replay_buffer.store_joint(
#         np.stack(cams), np.stack(vecs), 
#         act_np,
#         np.array(r_list, dtype=np.float32).reshape(len(active_agents), 1),
#         np.stack(cam2), np.stack(vec2),
#         np.array(d_list, dtype=np.float32).reshape(len(active_agents), 1),
#     )
#     mean_r = np.mean(r_list)
#     ema_reward = (1.0 - ema_alpha) * ema_reward + ema_alpha * np.mean(r_list)

#     if steps % 4096 == 0:
#         a_loss, c_loss = agent.train(replay_buffer, step=steps)
    
#         if c_loss > 1e6:
#             agent.load("SavedModels/SAC_distilled_trained.pth")
#             print("Critic loss too high, reloading model")
    
#         if last_ema_reward < ema_reward:
#             last_ema_reward = ema_reward
#             agent.save("SavedModels/SAC_distilled_trained.pth")
#             print("Saving model with mean reward", ema_reward)

    
            
#         wandb.log({
#             "EMA Reward": ema_reward,
#             "Mean Reward": mean_r,
#             "Mean Reward p5/50/95": np.percentile(r_list, [5, 50, 95]).round(3),
#             "Actor Loss": a_loss,
#             "Critic Loss": c_loss,
#             "Steps": steps,
#             "Goal Reached": goal_reached,
#             "Crashes": crashes,
#         }, step=steps)
    
#     if cc_weights != agent.ccritic.q_heads[0].weight.abs().mean():
#         print("CC weights", agent.ccritic.q_heads[0].weight.abs().mean(), "and CC_prev", cc_weights)
#         cc_weights = agent.ccritic.q_heads[0].weight.abs().mean()

#     if steps % 10_000 == 0:
#         arr = np.array(r_list)
#         print("Reward p5/50/95:",
#             *np.percentile(arr, [5, 50, 95]).round(3))

#     obs = next_obs

# env.close()