In [1]:
import os, random
import numpy as np
import mlagents
from mlagents_envs.environment import UnityEnvironment as UE
from mlagents_envs.envs.unity_parallel_env import UnityParallelEnv as UPZBE
from SAC_Distillation.DistilledSACAgent import DistilledSAC
from SAC_Distillation.Trajectories import SAC_ExperienceBuffer
from Hyperparameters import HYPERPARAMS as params
import numpy as np
import torch
import wandb

SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.set_float32_matmul_precision("high")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [2]:
def relocate_agents(env):
    return sorted(list(env.agents))

# New helper to extract observation data for an agent
def get_agent_obs(obs, agent, *, cam_key=1, vec_keys=[0,2]):
    if agent not in obs:
        raise KeyError(f"Agent {agent!r} not found in obs (keys: {list(obs.keys())[:8]}...)")

    data = obs[agent]
    if isinstance(data, dict) and "observation" in data:
        data = data["observation"]

    # Case A: explicit keys
    if isinstance(data, dict) and ("camera_obs" in data and "vector_obs" in data):
        cam = np.asarray(data["camera_obs"])
        vec = np.asarray(data["vector_obs"])
        # ensure vec is 1D
        if vec.ndim > 1:
            vec = vec.reshape(-1)
    else:
        # Case B: indexed container
        cam = np.asarray(data[cam_key])
        v0 = np.asarray(data[vec_keys[0]]).reshape(-1)
        v1 = np.asarray(data[vec_keys[1]]).reshape(-1)
        vec = np.concatenate([v0, v1], axis=0)

    # ---- Camera post-processing: to CHW float32 in [0,1] ----
    if cam.ndim != 3:
        raise AssertionError(f"Camera observation must be 3D (HWC or CHW), got shape {cam.shape}")

    # If HWC (channel last), move to CHW
    if cam.shape[-1] in (1, 3, 4):
        cam = np.transpose(cam, (2, 0, 1))

    cam = cam.astype(np.float32, copy=False)
    if cam.max() > 1.5:  # likely uint8 [0..255]
        cam = cam / 255.0

    vec = vec.astype(np.float32, copy=False)

    return cam, vec

In [3]:
env = UE(file_name="Env/Level1/DroneFlightv1", seed=1)
env = UPZBE(env)

In [4]:
agents = relocate_agents(env)
N_AGENTS = len(agents)

In [5]:
obs = env.reset()
cam_shape = env.observation_space(agents[0])[1].shape
vec_dim = env.observation_space(agents[0])[0].shape[0] + env.observation_space(agents[0])[2].shape[0]
vec_shape = (vec_dim,)
action_shape = env.action_space(agents[0]).shape
print("Agents:", N_AGENTS, "| cam_shape:", cam_shape, "| vec_dim:", vec_shape[0], "| act:", action_shape)

Agents: 4 | cam_shape: (4, 84, 84) | vec_dim: 84 | act: (4,)


In [6]:
replay_buffer = SAC_ExperienceBuffer(cam_shape, vec_shape,action_shape, params['sac_distilled'])

In [7]:
agent = DistilledSAC(cam_shape, vec_shape, action_shape,len(agents), params['sac_distilled'])

feat_path = "SavedModels/feature_extractor_contrastive_init.pth"
if os.path.isfile(feat_path):
    state = torch.load(feat_path, map_location=device)
    agent.model.convolution_pipeline.load_state_dict(state, strict=False)
else:
    print(f"⚠️  Init features not found: {feat_path}")

In [8]:
cfg = params['sac_distilled']
RAND_STEPS = cfg.get("n_steps_random_exploration", 100_000)
SEED_EPISODES = cfg.get("seed_episodes", 1)
blank_cam = np.zeros(cam_shape, dtype=np.float32)
blank_vec = np.zeros(vec_shape, dtype=np.float32)

In [9]:
import datetime as dt
run_name = f"sac_distill_{dt.datetime.now():%Y%m%d_%H%M%S}_level_1"
wandb.init(
            project=os.getenv("WANDB_PROJECT", "SAC_Distillation"),
            entity =os.getenv("WANDB_ENTITY",  "fede-"),
            name=run_name,
            config = {**params["sac_distilled"], "device": str(device)},
        )

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrullofederico16[0m ([33mfede-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# --- Config / counters ---
total_updates = 0
train_every   = 4096
log_every     = 4096
print_every   = 10000
max_steps     = cfg.get("max_steps", 5_000_000)

ema_reward      = 0.0
last_ema_reward = -np.inf
ema_alpha       = cfg.get("ema_alpha", 0.01)

# --- Reset & stable agent order (keep constant over run) ---
obs   = env.reset()
AGENTS = relocate_agents(env)            # must preserve order
N_AGENTS = len(AGENTS)

# Expect get_agent_obs to return CHW float32 in [0,1]; vec 1D float32
blank_cam = np.zeros((cam_shape), dtype=np.float32)
blank_vec = np.zeros((vec_shape), dtype=np.float32)

# --- Preallocate step buffers (reuse each step) ---
cam_now   = np.empty((N_AGENTS, *cam_shape), dtype=np.float32)
vect_now  = np.empty((N_AGENTS, *vec_shape), dtype=np.float32)
cam_next  = np.empty_like(cam_now)
vect_next = np.empty_like(vect_now)
rew_now   = np.empty((N_AGENTS, 1), dtype=np.float32)
done_now  = np.zeros((N_AGENTS, 1), dtype=np.float32)   # continuous env: always zeros

steps = 0
goal_reached = 0.0
crashed = 0.0


# Mapping from canonical AGENTS -> index
agent_to_idx = {aid: i for i, aid in enumerate(AGENTS)}

obs = env.reset()

In [11]:
while steps < max_steps:
    # If no live agents (obs empty), reset and continue
    if not obs or (isinstance(obs, dict) and len(obs) == 0):
        obs = env.reset()
        continue

    # Live agents this step (use obs keys)
    live = relocate_agents(env)

    # --- PACK CURRENT OBS INTO FIXED SLOTS ---
    # Start with blanks
    cam_now[:] = blank_cam
    vect_now[:]  = blank_vec
    for aid in live:
        i = agent_to_idx.get(aid, None)
        if i is None:
            # New agent id we haven't seen before -> extend canonical arrays (rare)
            # Simple choice: skip it (or rebuild AGENTS/arrays if you prefer)
            continue
        cam, vec = get_agent_obs(obs, aid)  # cam: (C,H,W) float[0,1], vec: (D,)
        cam_now[i]  = cam
        vect_now[i] = vec

    cam_t = torch.from_numpy(cam_now).to(device)
    vec_t = torch.from_numpy(vect_now).to(device)
    if torch.isnan(cam_t).any() or torch.isnan(vec_t).any():
        cam_t = torch.nan_to_num(cam_t); vec_t = torch.nan_to_num(vec_t)

    with torch.no_grad():
        act_t, _ = agent.get_action(cam_t, vec_t, train=True)  # (N_AGENTS, act_dim)
    act_np = act_t.detach().cpu().numpy()
    np.clip(act_np, -1.0, 1.0, out=act_np)

    # --- BUILD ACTIONS ONLY FOR LIVE AGENTS ---
    actions = {aid: act_np[agent_to_idx[aid]] for aid in live}

    # If somehow live is empty, reset and continue
    if len(actions) == 0:
        obs = env.reset()
        continue

    # --- STEP ENV ---
    next_obs, rew_dict, done_dict, infos = env.step(actions)
    steps += 1

    # --- PACK NEXT OBS & REWARDS INTO FIXED SLOTS ---
    cam_next[:]  = blank_cam
    vect_next[:] = blank_vec
    rew_now[:]   = 0.0
    # done_now stays 0 in continuous env

    # Note: next_obs may also have fewer agents; fill from those keys
    next_live = list(next_obs.keys()) if isinstance(next_obs, dict) else []
    for aid in next_live:
        i = agent_to_idx.get(aid, None)
        if i is None:
            continue
        cam_n, vec_n = get_agent_obs(next_obs, aid)
        cam_next[i]  = cam_n
        vect_next[i] = vec_n

    # Rewards (use the agents we acted for this step)
    for aid in live:
        i = agent_to_idx.get(aid, None)
        if i is None:
            continue
        r = float(rew_dict.get(aid, 0.0)) + float(infos.get(aid, {}).get("reward", 0.0))
        rew_now[i, 0] = r
        # event counters (optional thresholds)
        if r > 70:   goal_reached += 1
        if r < -99:  crashed += 1

    # Store transition
    priority = 3.0 if (rew_now.max() >= 70.0) else 1.0
    replay_buffer.store_joint(
        cam_now, vect_now, act_np,
        rew_now,
        cam_next, vect_next,
        done_now,                  # all zeros; continuous env
        priority=priority
    )

    # --- TRAIN / LOG (unchanged, with safe conversions) ---
    mean_r     = float(np.mean(rew_now))
    ema_reward = ema_reward * (1.0 - ema_alpha) + mean_r * ema_alpha

    curr_size = replay_buffer.size() if callable(getattr(replay_buffer, "size", None)) else replay_buffer.size
    if curr_size < (agent.batch_size * max(1, getattr(agent, "num_agents", N_AGENTS))):
        if steps % print_every == 0:
            print(f"Steps: {steps} | EMA Reward: {ema_reward:.3f} | Buffer: {curr_size}", end="\r")
        obs = next_obs
        continue

    if steps % train_every == 0:
        a_loss, c_loss, intrinsic_rew, rnd_loss = agent.train(replay_buffer, step=steps)
        total_updates += 1
        c_loss_val = float(c_loss.detach().cpu()) if isinstance(c_loss, torch.Tensor) else float(c_loss)
        if c_loss_val > 1e6:
            try:
                agent.load("SavedModels/SAC_distilled_trained_level_1.pth")
                print("\n[Guard] Critic loss exploded; reloaded last checkpoint.", end="\r")
            except Exception:
                print("\n[Guard] Critic loss exploded; no checkpoint to load.", end="\r")

        if ema_reward > last_ema_reward:
            last_ema_reward = ema_reward
            try:
                agent.save("SavedModels/SAC_distilled_trained_level_1.pth")
                print(f"\n[Checkpoint] New best EMA reward: {last_ema_reward:.2f}", end="\r")
            except Exception:
                pass

        if steps % log_every == 0:
            print(f"\nStep: {steps} | Goals: {goal_reached:.0f} | Crashes: {crashed:.0f}")
            print(f"Losses: Actor={float(a_loss):.4f} | Critic={c_loss_val:.4f} | RND={float(rnd_loss):.4f} | Intr={float(intrinsic_rew):.4f}",end="\r")
            print(f"Mean R: {mean_r:.4f} | EMA R: {ema_reward:.4f}", end="\r")
            if 'wandb' in globals():
                wandb.log({
                    "EMA Reward": float(ema_reward),
                    "Mean Reward": float(mean_r),
                    "Actor Loss": float(a_loss),
                    "Critic Loss": c_loss_val,
                    "Intrinsic Reward": float(intrinsic_rew),
                    "RND Loss": float(rnd_loss),
                    "Steps": int(steps),
                    "Goal Events": float(goal_reached),
                    "Crash Events": float(crashed),
                    "Updates": int(total_updates),
                }, step=int(steps))
            goal_reached = 0.0
            crashed = 0.0

    # Keep streaming (continuous env)
    obs = next_obs




[Checkpoint] New best EMA reward: -4.41
Step: 4096 | Goals: 0 | Crashes: 288
Mean R: -0.0000 | EMA R: -4.41180000 | RND=0.0000 | Intr=0.0000
[Checkpoint] New best EMA reward: -4.22
Step: 8192 | Goals: 0 | Crashes: 285
Mean R: -3.7332 | EMA R: -4.21680000 | RND=0.0000 | Intr=0.0000
Step: 12288 | Goals: 0 | Crashes: 295
Mean R: -3.0059 | EMA R: -4.22350000 | RND=0.0000 | Intr=0.0000
Step: 16384 | Goals: 0 | Crashes: 287
Mean R: -2.0034 | EMA R: -4.56940000 | RND=0.0000 | Intr=0.0000
[Checkpoint] New best EMA reward: -4.16
Step: 20480 | Goals: 0 | Crashes: 286
Mean R: -3.3602 | EMA R: -4.15730000 | RND=0.0000 | Intr=0.0000
[Checkpoint] New best EMA reward: -4.06
Step: 24576 | Goals: 0 | Crashes: 293
Mean R: -26.0264 | EMA R: -4.0587000 | RND=0.0000 | Intr=0.0000
Step: 28672 | Goals: 0 | Crashes: 290
Mean R: -2.5717 | EMA R: -4.49610000 | RND=0.0000 | Intr=0.0000
[Checkpoint] New best EMA reward: -3.91
Step: 32768 | Goals: 0 | Crashes: 286
Mean R: -2.0696 | EMA R: -3.90640000 | RND=0.0000




Step: 102400 | Goals: 0 | Crashes: 280
Mean R: -2.3388 | EMA R: -3.9040.7553 | RND=0.0132 | Intr=0.0130
Step: 106496 | Goals: 0 | Crashes: 283
Mean R: -2.2744 | EMA R: -4.2834.2062 | RND=0.0123 | Intr=0.0125
Step: 110592 | Goals: 0 | Crashes: 281
Mean R: -2.9183 | EMA R: -4.0486.8580 | RND=0.0125 | Intr=0.0125
Step: 114688 | Goals: 0 | Crashes: 277
Mean R: -1.8780 | EMA R: -4.1089.1798 | RND=0.0119 | Intr=0.0121
Step: 118784 | Goals: 0 | Crashes: 310
Mean R: -1.4417 | EMA R: -4.7544.1406 | RND=0.0370 | Intr=0.0370
Step: 122880 | Goals: 0 | Crashes: 366
Mean R: -3.2445 | EMA R: -4.87202.6954 | RND=0.0392 | Intr=0.0392
Step: 126976 | Goals: 0 | Crashes: 356
Mean R: -3.2875 | EMA R: -4.8083.9194 | RND=0.0367 | Intr=0.0367
Step: 131072 | Goals: 0 | Crashes: 285
Mean R: -3.7609 | EMA R: -4.4344.0405 | RND=0.0387 | Intr=0.0390
Step: 135168 | Goals: 0 | Crashes: 254
Mean R: -3.7476 | EMA R: -4.2121.7235 | RND=0.0345 | Intr=0.0345
Step: 139264 | Goals: 0 | Crashes: 238
Mean R: -2.5197 | EMA R

KeyboardInterrupt: 

In [None]:
# Clean up
try:
    env.close()
except Exception:
    pass
agent.load(path="SavedModels/SAC_distilled_trained_level_1.pth")

In [None]:
env = UE(file_name="Env/FinalLevel/DroneFlightv1", seed=1)
env = UPZBE(env)

In [None]:
agents = list(set(env.agents))
print(agents)
print(env.action_spaces)
print(env._env.behavior_specs.values())

['Drone?team=0?agent_id=2', 'Drone?team=0?agent_id=0', 'Drone?team=0?agent_id=3', 'Drone?team=0?agent_id=1']
{'Drone?team=0?agent_id=2': Box(-1.0, 1.0, (4,), float32), 'Drone?team=0?agent_id=0': Box(-1.0, 1.0, (4,), float32), 'Drone?team=0?agent_id=3': Box(-1.0, 1.0, (4,), float32), 'Drone?team=0?agent_id=1': Box(-1.0, 1.0, (4,), float32)}
ValuesView(<mlagents_envs.base_env.BehaviorMapping object at 0x000001DB73913FA0>)


In [None]:
del replay_buffer
replay_buffer = SAC_ExperienceBuffer(cam_shape, vec_shape, action_shape, params['sac_distilled'])

In [None]:
tot_steps = RAND_STEPS * SEED_EPISODES   # e.g. 5 120 000

obs_dict = env.reset()                   # one reset BEFORE the loop

for step in range(tot_steps):
    if not obs_dict:
        obs_dict = env.reset()
        continue
    agents = relocate_agents(env)  # get the current agents in the environment
    # --- draw a random joint action ---------------------------------
    act_dict = {a: env.action_space(a).sample() for a in agents}

    cam_now  = np.empty((N_AGENTS, *cam_shape),   dtype=np.float32)
    vect_now = np.empty((N_AGENTS, *vec_shape),   dtype=np.float32)
    act_now  = np.empty((N_AGENTS, *action_shape),dtype=np.float32)

    for i, a in enumerate(agents):
        cam, vec = get_agent_obs(obs_dict, a) if a in obs_dict else (blank_cam, blank_vec)
        cam_now[i], vect_now[i], act_now[i] = cam, vec, act_dict[a]

    # --- take one step ----------------------------------------------
    next_obs, rew_dict, done_dict, _ = env.step(act_dict)

    # --- pack the next‐state tensors --------------------------------
    cam_next  = np.empty_like(cam_now)
    vect_next = np.empty_like(vect_now)
    rew_now   = np.zeros((N_AGENTS, 1), dtype=np.float32)
    done_now  = np.zeros((N_AGENTS, 1), dtype=np.float32)

    for i, a in enumerate(agents):
        cam_n, vec_n = get_agent_obs(next_obs, a) if a in next_obs else (blank_cam, blank_vec)
        cam_next[i], vect_next[i] = cam_n, vec_n
        rew_now[i, 0]  = rew_dict.get(a, 0.0)
        done_now[i, 0] = float(done_dict.get(a, False))

    replay_buffer.store_joint(cam_now, vect_now, act_now,
                              rew_now, cam_next, vect_next, done_now)

    obs_dict = next_obs

    # if the whole team is done, start a new episode
    if all(done_dict.values()):
        obs_dict = env.reset()

print("Finished collecting random steps")


Finished collecting random steps


In [None]:
import datetime as dt
run_name = f"sac_distill_{dt.datetime.now():%Y%m%d_%H%M%S}"
wandb.init(
            project=os.getenv("WANDB_PROJECT", "SAC_Distillation"),
            entity =os.getenv("WANDB_ENTITY",  "fede-"),
            name=run_name,
            config = {**params["sac_distilled"], "device": str(device)},
        )

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrullofederico16[0m ([33mfede-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
total_updates = 0
train_every = 4_096
seed_step_offline_distil = 1
log_every = 1000
print_every = 10_000
max_steps = cfg.get("max_steps", 5_000_000)
goal_reached = 0.0
crashed = 0.0


ema_reward = 0.0
last_ema_reward = -np.inf
ema_alpha = cfg.get("ema_alpha", 0.01)


obs = env.reset()
steps=0

In [None]:
if seed_step_offline_distil > 0:
    agent.offline_distill(replay_buffer,batch_size=256)


while steps < max_steps:
    if not obs:
        obs = env.reset()
        continue
    agents = relocate_agents(env)
    cam_now = np.zeros((N_AGENTS, *cam_shape), dtype=np.uint8)
    vect_now = np.zeros((N_AGENTS, *vec_shape), dtype=np.float32)

    for i, aid in enumerate(agents):
        if aid in obs:
            cam, vec = get_agent_obs(obs, aid)
        else:
            cam, vec = blank_cam, blank_vec
        cam_now[i]  = cam
        vect_now[i] = vec

    cam_t = torch.from_numpy(np.stack(cam_now)).float().to(device)
    vec_t = torch.from_numpy(np.stack(vect_now)).float().to(device)

    if cam_t.isnan().any() or vec_t.isnan().any():
        cam_t = torch.nan_to_num(cam_t)
        vec_t = torch.nan_to_num(vec_t)
    with torch.no_grad():
        step_fraction = steps / max_steps
        act_t = agent.get_action(cam_t, vec_t, train=False)
    

    # act_np = torch.round(act_t).clamp(-1,1).cpu().numpy()
    act_np = act_t.cpu().numpy()
    actions = {aid: action for aid, action in zip(agents, act_np)}
    # print(actions)
    

    next_obs, rew_dict, done_dict, infos = env.step(actions)
    steps += 1

    cam_next = np.zeros_like(cam_now)
    vect_next = np.zeros_like(vect_now)
    rew_now = np.zeros((N_AGENTS, 1), dtype=np.float32)
    done_now = np.zeros((N_AGENTS, 1), dtype=np.float32)

    

    for i, aid in enumerate(agents):
        if aid in next_obs:
            cam_n, vec_n = get_agent_obs(next_obs, aid)
        else:
            cam_n, vec_n = blank_cam, blank_vec
        cam_next[i] = cam_n
        vect_next[i] = vec_n

        r = rew_dict.get(aid, 0.0) + infos.get(aid, {}).get('reward', 0.0)
        rew_now[i, 0] = r
        done_now[i, 0] = done_dict.get(aid, False)

        goal_reached += 1 if r > 19 else 0
        crashed += 1 if r < -9 else 0

    replay_buffer.store_joint(
        cam_now, vect_now, act_np,
        rew_now,
        cam_next, vect_next,
        done_now
    )

    mean_r = np.mean(rew_now).item()
    ema_reward = ema_reward * (1- ema_alpha) + mean_r * ema_alpha
    
    if replay_buffer.size < agent.batch_size * agent.num_agents:
        continue

    if steps % train_every == 0:
        a_loss, c_loss, intrinsic_rew, rnd_loss = agent.train(replay_buffer, step=steps)
        total_updates += 1

        if c_loss > 1e6:
            agent.load("SavedModels/SAC_distilled_trained_fix.pth")
            print("Critic loss exploded, reloading model")
        
        if ema_reward > last_ema_reward:
            last_ema_reward = ema_reward
            agent.save("SavedModels/SAC_distilled_trained_fix.pth")
            print(f"New best EMA reward: {last_ema_reward:.2f}")

        if steps % log_every==0:
            wandb.log({
                "EMA Reward": ema_reward,
                "Mean Reward": mean_r,
                "Actor Loss": a_loss,
                "Critic Loss": c_loss,
                "Steps": steps,
                "Goal Reached": goal_reached,
                "Crashes": crashed,
                "Intrinsic Reward": intrinsic_rew,
                "RND Loss": rnd_loss,
            }, step=steps)
            
            goal_reached = 0.0
            crashed = 0.0

    obs = next_obs

env.close()

New best EMA reward: -0.17
New best EMA reward: -0.03
New best EMA reward: -0.00




New best EMA reward: 0.00
New best EMA reward: 0.02
New best EMA reward: 0.03
New best EMA reward: 0.03
New best EMA reward: 0.04
New best EMA reward: 0.05
New best EMA reward: 0.05
New best EMA reward: 0.06
New best EMA reward: 0.06
New best EMA reward: 0.07
New best EMA reward: 0.07


KeyboardInterrupt: 