In [1]:
import os, math, random, time, json, pathlib
from pathlib import Path
from typing import Tuple, Dict, Any
import numpy as np

import torch
import torch.nn.functional as F

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.set_float32_matmul_precision("high")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.autograd.set_detect_anomaly(False)

ENABLE_COMPILE = False

print("Device:", device)

Device: cuda


In [2]:
from Hyperparameters import SAC_DISTILLED as H
from SAC_Distillation.DistilledSACAgent import DistilledSAC
from SAC_Distillation.TeacherModel import TeacherModel
from SAC_Distillation.Trajectories import SAC_ExperienceBuffer

SHORT_RUN = False
if SHORT_RUN:
    H["train_steps"] = 150_000
    H["warmup_steps"] = 10_000




In [3]:
H.update({
    "use_drq": True,
    "drq_pad": 4,
    "intrinsic_reward_coef":        0.5,
    "intrinsic_reward_coef_final":  0.15,
    "intrinsic_coef_decay_steps":   1_000_000,
    "warmup_steps": 100_000,
})
print(json.dumps(H, indent=2))

{
  "gamma": 0.99,
  "gae_lambda": 0.95,
  "tau": 0.005,
  "actor_lr": 0.0003,
  "critic_lr": 0.0003,
  "alpha_lr": 0.0003,
  "distill_lr": 0.0001,
  "rnd_lr": 5e-05,
  "critic_updates": 3,
  "actor_updates": 2,
  "distill_coef": 0.06,
  "distill_epochs": 2,
  "distill_batch": 256,
  "distill_temp": 0.07,
  "distill_frames": 50000,
  "buffer_size": 1000000,
  "batch_size": 512,
  "seed_episodes": 1,
  "n_steps_random_exploration": 10000,
  "noise_std": 0.2,
  "smooth_clip": 0.2,
  "rnd_update_proportion": 0.05,
  "intrinsic_reward_coef": 0.5,
  "intrinsic_reward_coef_final": 0.15,
  "intrinsic_coef_decay_steps": 1000000,
  "extrinsic_reward_coef": 1.0,
  "max_steps": 5000000,
  "policy_delay": 2,
  "ema_alpha": 0.01,
  "curiosity_coeff": 0.1,
  "warmup_steps": 100000,
  "use_drq": true,
  "drq_pad": 4
}


In [4]:
import numpy as np
from gymnasium import spaces

# ---------- shared helpers ----------
def _stable_unique(seq):
    return list(dict.fromkeys(seq))

def _to_chw01(cam: np.ndarray) -> np.ndarray:
    if cam.ndim != 3:
        raise AssertionError(f"Camera observation should be 3D (HWC or CHW), got {cam.shape}")
    if cam.shape[-1] in (1, 3, 4):  # HWC -> CHW
        cam = np.transpose(cam, (2, 0, 1))
    cam = cam.astype(np.float32, copy=False)
    if cam.max() > 1.5:
        cam = cam / 255.0
    return cam

def _to_vec1d(vec: np.ndarray) -> np.ndarray:
    vec = np.asarray(vec)
    if vec.ndim > 1:
        vec = vec.reshape(-1)
    return vec.astype(np.float32, copy=False)

def _extract_agent_obs(obs, agent, *, cam_key=1, vec_keys=(0,2)):
    if agent not in obs:
        raise KeyError(f"Agent {agent!r} not in obs keys: {list(obs.keys())[:8]}...")
    data = obs[agent]
    if isinstance(data, dict) and "observation" in data:
        data = data["observation"]

    if isinstance(data, dict) and ("camera_obs" in data and "vector_obs" in data):
        cam = np.asarray(data["camera_obs"])
        vec = np.asarray(data["vector_obs"])
    else:
        cam = np.asarray(data[cam_key])
        parts = []
        for k in vec_keys:
            v = np.asarray(data[k])
            if v.ndim > 1:
                v = v.reshape(-1)
            parts.append(v)
        vec = np.concatenate(parts, axis=0) if len(parts) > 1 else parts[0]

    cam = _to_chw01(cam)
    vec = _to_vec1d(vec)
    return cam, vec

def _stack_from_per_agent(obs_dict, agent_ids, *, cam_key=1, vec_keys=(0,2)):
    cams, vecs = [], []
    for aid in agent_ids:
        cam, vec = _extract_agent_obs(obs_dict, aid, cam_key=cam_key, vec_keys=vec_keys)
        cams.append(cam); vecs.append(vec)
    camera_obs = np.stack(cams, axis=0)  # [B,C,H,W]
    maxd = max(v.shape[0] for v in vecs)
    vecs = [np.pad(v, (0, maxd - v.shape[0])) if v.shape[0] != maxd else v for v in vecs]
    vector_obs = np.stack(vecs, axis=0)  # [B,D]
    return {"camera_obs": camera_obs, "vector_obs": vector_obs}

def _to_rew(x, agent_ids):
    if isinstance(x, dict):
        arr = np.asarray([x.get(a, 0.0) for a in agent_ids], dtype=np.float32)
    elif isinstance(x, (list, tuple, np.ndarray)):
        arr = np.asarray(x, dtype=np.float32)
    else:
        arr = np.asarray([x], dtype=np.float32)
    return arr.reshape(-1, 1)

def _to_done_dict_false(agent_ids):
    return {a: False for a in agent_ids}

def _to_done(x, agent_ids):
    if isinstance(x, dict):
        arr = np.asarray([x.get(a, False) for a in agent_ids], dtype=bool)
    elif isinstance(x, (list, tuple, np.ndarray)):
        arr = np.asarray(x, dtype=bool)
    else:
        arr = np.asarray([x], dtype=bool)
    return arr.reshape(-1, 1)

# ---------- UnityParallelEnv adapter (PettingZoo Parallel) ----------
class ParallelMultiDroneAdapter:
    """
    Adapter for your UnityParallelEnv (PettingZoo Parallel).
    Exposes a Gym-like API:
      reset() -> (batched_obs, {})
      step(batched_actions) -> (batched_obs, rewards[B,1], terminated[B,1], truncated[B,1], info)
    """
    def __init__(self, env, *, cam_key=1, vec_keys=(0,2)):
        self.env = env
        self.cam_key = cam_key
        self.vec_keys = vec_keys
        self.agent_ids = None
        self.num_agents = None
        self._batched_action_space = None

    # --- action space resolution (must be continuous Box for SAC) ---
    def _resolve_single_action_space(self, agent_id):
        # Try common patterns
        sp = None
        if hasattr(self.env, "action_space") and callable(getattr(self.env, "action_space", None)):
            sp = self.env.action_space(agent_id)
        elif hasattr(self.env, "action_space") and isinstance(self.env.action_space, dict):
            sp = self.env.action_space.get(agent_id)
        elif hasattr(self.env, "action_spaces"):
            sp = self.env.action_spaces.get(agent_id)

        if isinstance(sp, spaces.Box):
            return sp

        if isinstance(sp, spaces.Dict):
            boxes = [v for v in sp.spaces.values() if isinstance(v, spaces.Box)]
            if len(boxes) == 1:
                return boxes[0]
            raise AssertionError(
                f"Dict action space has {len(sp.spaces)} entries ({len(boxes)} Box). "
                "Expose a single continuous Box or implement concatenation."
            )

        if isinstance(sp, spaces.Tuple):
            boxes = [s for s in sp.spaces if isinstance(s, spaces.Box)]
            if len(sp.spaces) == 1 and len(boxes) == 1:
                return boxes[0]
            raise AssertionError(
                f"Tuple action space has {len(sp.spaces)} entries (Box count: {len(boxes)}). "
                "Expose a single continuous Box or implement concatenation."
            )

        if isinstance(sp, (spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary)) or sp is None:
            raise AssertionError(
                f"SAC requires a continuous Box action space, got {type(sp).__name__ if sp is not None else 'None'}."
            )

        raise AssertionError(f"Unsupported per-agent action space type: {type(sp).__name__}")

    def _get_batched_action_space(self):
        if self.agent_ids is None:
            # Need one reset to know agents
            _ = self.reset()
        first = self.agent_ids[0]
        box = self._resolve_single_action_space(first)

        low  = np.repeat(box.low[None, ...], self.num_agents, axis=0)
        high = np.repeat(box.high[None, ...], self.num_agents, axis=0)

        batched = spaces.Box(low=low, high=high, dtype=box.dtype)
        batched.shape = (self.num_agents, box.shape[0])

        def _sample():
            return np.stack([box.sample() for _ in range(self.num_agents)], axis=0)
        batched.sample = _sample
        return batched

    # Allow both callable and property usage
    def action_space(self):
        return self._get_batched_action_space()
    @property
    def action_space_property(self):
        return self._get_batched_action_space()

    # --- Gym-like API built on your UnityParallelEnv signatures ---
    def reset(self, **kwargs):
        # Your UnityParallelEnv.reset() -> observations (dict), no info
        obs = self.env.reset(**kwargs)
        if not isinstance(obs, dict):
            raise TypeError("UnityParallelEnv.reset() must return a dict of per-agent observations.")
        if self.agent_ids is None:
            self.agent_ids = _stable_unique(list(obs.keys())) if len(obs) > 0 else _stable_unique(getattr(self.env, "agents", []))
            self.num_agents = len(self.agent_ids)

        batched = _stack_from_per_agent(obs, self.agent_ids, cam_key=self.cam_key, vec_keys=self.vec_keys)
        return batched, {}  # empty info for Gym-compat

    def step(self, action_batched):
        # Map (B, act_dim) -> per-agent dict
        if isinstance(action_batched, dict):
            action_dict = action_batched
        else:
            action_batched = np.asarray(action_batched)
            assert action_batched.shape[0] == self.num_agents, f"Expected {self.num_agents} actions, got {action_batched.shape}"
            action_dict = {aid: action_batched[i] for i, aid in enumerate(self.agent_ids)}

        # Your UnityParallelEnv.step(actions) -> (obs, rewards, dones, infos)
        next_obs, rewards, dones, infos = self.env.step(action_dict)

        if not isinstance(next_obs, dict):
            raise TypeError("UnityParallelEnv.step() must return per-agent observation dict.")

        batched_obs = _stack_from_per_agent(next_obs, self.agent_ids, cam_key=self.cam_key, vec_keys=self.vec_keys)
        rew  = _to_rew(rewards, self.agent_ids)     # (B,1)
        term = _to_done(dones,   self.agent_ids)    # (B,1)
        # UnityParallelEnv has no truncations -> synthesize all False
        trunc = _to_done(_to_done_dict_false(self.agent_ids), self.agent_ids)

        # infos is already a dict (per agent / global); pass it through
        return batched_obs, rew, term, trunc, infos

    # passthroughs
    def close(self): 
        return getattr(self.env, "close", lambda: None)()
    def render(self, *a, **k):
        return getattr(self.env, "render", lambda *a, **k: None)(*a, **k)


In [5]:
def relocate_agents(env):
    """
    Return a stable, deduplicated list of agent IDs.
    Works for:
      - PettingZoo-style envs with `env.agents`
      - Unity/Gym wrappers exposing `num_agents` or agent names
      - Falls back to a range of indices if no IDs are present
    """
    # PettingZoo-style
    if hasattr(env, "agents") and isinstance(getattr(env, "agents"), (list, tuple)):
        return list(dict.fromkeys(env.agents))  # stable-unique

    # Some Unity wrappers expose number of agents
    if hasattr(env, "num_agents") and isinstance(env.num_agents, (int, np.integer)):
        return list(range(int(env.num_agents)))

    # Try to infer from observation space if available
    if hasattr(env, "observation_space") and hasattr(env.observation_space, "n"):
        return list(range(int(env.observation_space.n)))

    # Last resort: single agent
    return [0]

In [6]:
import mlagents
from mlagents_envs.environment import UnityEnvironment as UE
from mlagents_envs.envs.unity_parallel_env import UnityParallelEnv as UPZBE

env = UE(file_name="Env/Level1/DroneFlightv1", seed=1)
env = UPZBE(env)
env = ParallelMultiDroneAdapter(env)

In [8]:
obs, info = env.reset()
obs.keys()

dict_keys(['camera_obs', 'vector_obs'])

# TeacherPretraining

In [9]:
teacher = TeacherModel().to(device)
if ENABLE_COMPILE:
    try:
        teacher = torch.compile(teacher, mode="reduce-overhead", fullgraph=False)
    except Exception as e:
        print("compile skipped:", e)

opt_t = torch.optim.AdamW(teacher.head.parameters(), lr=1e-3, weight_decay=1e-5)

def gather_random_frames(env, n_steps=2000):
    assert env is not None, "Instantiate env before running"
    frames = []
    with torch.no_grad():
        obs = env.reset()
        for _ in range(n_steps):
            a = env.action_space.sample()
            o2, r, done, info = env.step(a)
            cam = o2["camera_obs"]
            frames.append(torch.from_numpy(cam))
            if np.any(done):
                obs = env.reset()
            else:
                obs = o2

    return torch.cat(frames, dim=0)

In [11]:
print(env.action_space().shape)

AssertionError: Unsupported per-agent action space type: Box

In [None]:
replay = SAC_ExperienceBuffer(camera_obs_dim=obs["camera_obs"].shape, vector_obs_dim=obs["vector_obs"].shape,action_dim=env., params=H)


AssertionError: Only Box action spaces are supported in this adapter.