In [7]:
!pip install -r requirements.txt



##### Import Libraries

In [8]:
import numpy as np
import gymnasium as gym
from gymnasium import Wrapper
from dataclasses import dataclass
from typing import Callable, Dict, Any, Tuple, List
import os
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3 import PPO, SAC, TD3
from stable_baselines3.common.evaluation import evaluate_policy

##### Task Variants for Cheetah (Target Velocity)

In [9]:
class HalfCheetahTargetVelocity(Wrapper):
    """Reward wrapper that asks the agent to hold a target forward velocity.

    The wrapped environment's own reward is discarded. Each step instead returns
    ``-|vx - target_velocity| * vel_scale - ctrl_cost_weight * sum(action**2)``,
    where ``vx`` is read from ``qvel[0]`` of the unwrapped simulator state.

    Args:
        env: Environment to wrap; its unwrapped form must expose ``data.qvel``.
        target_velocity: Velocity the agent should track.
        vel_scale: Multiplier applied to the absolute velocity error.
        ctrl_cost_weight: Multiplier applied to the squared-action penalty.
    """

    def __init__(self, env, target_velocity: float, vel_scale: float = 1.0, ctrl_cost_weight: float = 0.1):
        super().__init__(env)
        self.vt = float(target_velocity)
        self.vel_scale = float(vel_scale)
        self.ctrl_cost_weight = float(ctrl_cost_weight)

    def step(self, action):
        """Step the wrapped env and substitute the velocity-tracking reward.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``info`` is a
            copy of the wrapped env's info extended with ``"vx"`` and
            ``"target_v"``.
        """
        obs, _base_reward, terminated, truncated, base_info = self.env.step(action)

        vx = float(self.env.unwrapped.data.qvel[0])  # forward velocity
        tracking_penalty = self.vel_scale * abs(vx - self.vt)
        effort_penalty = self.ctrl_cost_weight * float(np.sum(action**2))
        reward = -tracking_penalty - effort_penalty

        # Build a fresh dict so the wrapped env's info object is not mutated.
        info = {**base_info, "vx": vx, "target_v": self.vt}
        return obs, reward, terminated, truncated, info


##### Task Variants for Ant (Target Direction)

In [10]:
class AntTargetDirection(Wrapper):
    """Reward wrapper that pays the agent for moving along a fixed 2-D direction.

    The wrapped environment's own reward is discarded. Each step returns the
    finite-difference planar velocity (from ``qpos[0:2]`` of the unwrapped
    simulator state) projected onto the normalized ``direction``, minus
    ``ctrl_cost_weight * sum(action**2)``.

    Args:
        env: Environment to wrap; its unwrapped form must expose ``data.qpos``
            and ``dt``.
        direction: 2-D direction vector; it is normalized on construction.
        ctrl_cost_weight: Multiplier applied to the squared-action penalty.
    """

    def __init__(self, env, direction: np.ndarray, ctrl_cost_weight: float = 0.05):
        super().__init__(env)
        d = np.asarray(direction, dtype=np.float32)
        # The epsilon avoids a division by zero; a zero vector therefore yields a
        # zero direction, i.e. a reward made up of the control cost only.
        self.dir = d / (np.linalg.norm(d) + 1e-8)
        self.ctrl_cost_weight = float(ctrl_cost_weight)
        self._prev_xy = None

    def _current_xy(self) -> np.ndarray:
        """Return a float64 copy of the planar position ``qpos[0:2]``."""
        # float64 on purpose: the position difference is divided by a small dt,
        # so float32 rounding of the positions would be amplified in the velocity.
        return np.array(self.env.unwrapped.data.qpos[0:2], dtype=np.float64)

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the starting planar position."""
        obs, info = self.env.reset(**kwargs)
        self._prev_xy = self._current_xy()
        return obs, info

    def step(self, action):
        """Step the wrapped env and substitute the directional-speed reward.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``info`` is a
            copy of the wrapped env's info extended with ``"vel_xy"`` (float32),
            ``"dir"`` and ``"dir_speed"``.
        """
        if self._prev_xy is None:
            # reset() was not routed through this wrapper yet: take the pre-step
            # position as the reference instead of failing on ``None``.
            self._prev_xy = self._current_xy()

        obs, _, terminated, truncated, info = self.env.step(action)

        xy = self._current_xy()
        dt = float(self.env.unwrapped.dt)
        vel_xy = (xy - self._prev_xy) / max(dt, 1e-8)  # guard against dt == 0
        self._prev_xy = xy

        dir_speed = float(np.dot(vel_xy, self.dir))
        ctrl_cost = self.ctrl_cost_weight * float(np.sum(action**2))
        reward = dir_speed - ctrl_cost

        info = dict(info)
        info.update({"vel_xy": vel_xy.astype(np.float32), "dir": self.dir, "dir_speed": dir_speed})
        return obs, reward, terminated, truncated, info


##### Task Variants for Walker (Target Height)

In [11]:
class Walker2dTargetHeight(Wrapper):
    """Reward wrapper that asks the agent to hold a target height.

    The wrapped environment's own reward is discarded. Each step instead returns
    ``-|height - target_height| * height_scale - ctrl_cost_weight * sum(action**2)``,
    where ``height`` is read from ``qpos[1]`` of the unwrapped simulator state.

    Args:
        env: Environment to wrap; its unwrapped form must expose ``data.qpos``.
        target_height: Height the agent should maintain.
        height_scale: Multiplier applied to the absolute height error.
        ctrl_cost_weight: Multiplier applied to the squared-action penalty.
    """

    def __init__(self, env, target_height: float, height_scale: float = 1.0, ctrl_cost_weight: float = 0.001):
        super().__init__(env)
        self.ht = float(target_height)
        self.height_scale = float(height_scale)
        self.ctrl_cost_weight = float(ctrl_cost_weight)

    def step(self, action):
        """Step the wrapped env and substitute the height-tracking reward.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``info`` is a
            copy of the wrapped env's info extended with ``"height"`` and
            ``"target_h"``.
        """
        obs, _base_reward, terminated, truncated, base_info = self.env.step(action)

        height = float(self.env.unwrapped.data.qpos[1])  # torso height (typ.)
        height_penalty = self.height_scale * abs(height - self.ht)
        effort_penalty = self.ctrl_cost_weight * float(np.sum(action**2))
        reward = -height_penalty - effort_penalty

        # Build a fresh dict so the wrapped env's info object is not mutated.
        info = {**base_info, "height": height, "target_h": self.ht}
        return obs, reward, terminated, truncated, info

In [12]:
def task_base(env_id: str, seed: int = 0):
    """Create the unmodified environment ``env_id`` and reset it once with ``seed``.

    Returns the environment; the initial observation from the reset is discarded.
    """
    base_env = gym.make(env_id)
    base_env.reset(seed=seed)  # seeded reset at construction time
    return base_env

def task_halfcheetah_target_velocity(target_v: float, seed: int = 0):
    """Create ``HalfCheetah-v4`` wrapped to reward tracking ``target_v``.

    The wrapped environment is reset once with ``seed`` before being returned.
    """
    wrapped = HalfCheetahTargetVelocity(gym.make("HalfCheetah-v4"), target_velocity=target_v)
    wrapped.reset(seed=seed)
    return wrapped

def task_ant_target_direction(dx: float, dy: float, seed: int = 0):
    """Create ``Ant-v4`` wrapped to reward motion along the direction ``(dx, dy)``.

    The wrapped environment is reset once with ``seed`` before being returned.
    """
    heading = np.array([dx, dy], dtype=np.float32)
    wrapped = AntTargetDirection(gym.make("Ant-v4"), direction=heading)
    wrapped.reset(seed=seed)
    return wrapped

def task_walker2d_target_height(target_h: float, seed: int = 0):
    """Create ``Walker2d-v4`` wrapped to reward holding the height ``target_h``.

    The wrapped environment is reset once with ``seed`` before being returned.
    """
    wrapped = Walker2dTargetHeight(gym.make("Walker2d-v4"), target_height=target_h)
    wrapped.reset(seed=seed)
    return wrapped


In [13]:
@dataclass(frozen=True)
class Task:
    """A named task paired with a factory that builds its environment."""
    # Human-readable label; this notebook also embeds it in saved file names.
    name: str
    # Zero-argument callable that returns an environment for the task.
    make_env: Callable[[], gym.Env]


# Tasks to train teachers on. Only the unmodified ("BASE") environments are
# enabled; the reward-variant tasks further down are commented out.
tasks: List[Task] = [
    Task("BASE HalfCheetah", lambda: task_base("HalfCheetah-v4", seed=0)),
    Task("BASE Hopper",      lambda: task_base("Hopper-v4", seed=0)),
    Task("BASE Walker2d",    lambda: task_base("Walker2d-v4", seed=0)),
    Task("BASE Ant",         lambda: task_base("Ant-v4", seed=0))

    # NOTE: before enabling any entry below, add a trailing comma after the
    # "BASE Ant" entry above; without it the list will not parse.
    #Task("HC target_v=1.0",  lambda: task_halfcheetah_target_velocity(1.0, seed=1)),
    #Task("HC target_v=2.0",  lambda: task_halfcheetah_target_velocity(2.0, seed=2)),

    #Task("Ant EAST",         lambda: task_ant_target_direction(1.0, 0.0, seed=3)),
    #Task("Ant NORTH",        lambda: task_ant_target_direction(0.0, 1.0, seed=4)),

    #Task("Walker h=1.2",     lambda: task_walker2d_target_height(1.2, seed=5)),
    #Task("Walker h=1.6",     lambda: task_walker2d_target_height(1.6, seed=6)),
]


In [14]:
def build_vec_env(task: Task, seed: int = 0, normalize_obs: bool = True):
    """Wrap a task's environment as a single-env vectorized environment.

    Args:
        task: Task whose ``make_env`` factory creates the environment.
        seed: Seed passed to the reset performed right after the env is
            wrapped in ``Monitor``.
        normalize_obs: When true, wrap the result in ``VecNormalize`` with
            observation normalization on, reward normalization off and
            observations clipped to 10.0.

    Returns:
        A ``DummyVecEnv`` holding one monitored env, optionally wrapped in
        ``VecNormalize``.
    """
    def _make_monitored_env():
        # Monitor logs episode returns/lengths.
        monitored = Monitor(task.make_env())
        monitored.reset(seed=seed)
        return monitored

    vec_env = DummyVecEnv([_make_monitored_env])
    if not normalize_obs:
        return vec_env
    return VecNormalize(vec_env, norm_obs=True, norm_reward=False, clip_obs=10.0)


In [15]:
def make_teacher(algo: str, env, seed: int = 0, logdir: str = None):
    """Instantiate an untrained SB3 model for ``env``.

    Args:
        algo: ``"SAC"``, ``"TD3"`` or ``"PPO"``, matched case-insensitively.
        env: Environment handed to the algorithm constructor.
        seed: Seed forwarded to the algorithm.
        logdir: Directory forwarded as ``tensorboard_log``.

    Returns:
        The constructed model with an ``"MlpPolicy"`` policy.

    Raises:
        ValueError: If ``algo`` is not one of the supported names.
    """
    algo = algo.upper()

    # Hyperparameters specific to each supported algorithm.
    hyperparams_by_algo = {
        "SAC": {"batch_size": 256, "learning_rate": 3e-4, "gamma": 0.99},
        "TD3": {"batch_size": 256, "learning_rate": 1e-3, "gamma": 0.99},
        "PPO": {"n_steps": 2048, "batch_size": 64, "learning_rate": 3e-4, "gamma": 0.99},
    }
    if algo not in hyperparams_by_algo:
        raise ValueError(f"Unknown algo: {algo}")

    algo_cls = {"SAC": SAC, "TD3": TD3, "PPO": PPO}[algo]
    shared_kwargs = {"verbose": 1, "seed": seed, "tensorboard_log": logdir}
    return algo_cls("MlpPolicy", env, **hyperparams_by_algo[algo], **shared_kwargs)


def train_teacher_for_task(
    task: Task,
    algo: str = "SAC",
    total_timesteps: int = 300_000,
    seed: int = 0,
    normalize_obs: bool = True,
    out_dir: str = "./teachers",
    log_dir: str = "./tb_logs",
):
    """Train, evaluate and save one teacher policy for ``task``.

    Args:
        task: Task providing the environment factory and a name used in file names.
        algo: Algorithm name understood by ``make_teacher``.
        total_timesteps: Number of training timesteps.
        seed: Seed used for both the environment and the model.
        normalize_obs: Whether to train with ``VecNormalize`` observation normalization.
        out_dir: Directory for the saved model (and normalization statistics).
        log_dir: TensorBoard log directory.

    Returns:
        Dict with keys ``"task"``, ``"algo"``, ``"mean"``, ``"std"``,
        ``"model_path"`` and ``"vec_path"`` (``None`` when observations were
        not normalized).
    """
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    venv = build_vec_env(task, seed=seed, normalize_obs=normalize_obs)
    # try/finally so the environment is released even when training or
    # evaluation is interrupted (e.g. KeyboardInterrupt) or fails.
    try:
        model = make_teacher(algo, venv, seed=seed, logdir=log_dir)
        model.learn(total_timesteps=total_timesteps, progress_bar=True)

        uses_vecnormalize = isinstance(venv, VecNormalize)
        if uses_vecnormalize:
            # Freeze the running statistics so evaluation does not update them.
            venv.training = False
            venv.norm_reward = False

        mean_r, std_r = evaluate_policy(model, venv, n_eval_episodes=10, deterministic=True)
        print(f"[{task.name}] {algo} eval: {mean_r:.2f} +/- {std_r:.2f}")

        model_path = os.path.join(out_dir, f"{task.name}_{algo}.zip")
        model.save(model_path)

        vec_path = None
        if uses_vecnormalize:
            # The normalization statistics are needed to reuse the policy later.
            vec_path = os.path.join(out_dir, f"{task.name}_{algo}_vecnormalize.pkl")
            venv.save(vec_path)
    finally:
        venv.close()

    return {"task": task.name, "algo": algo, "mean": mean_r, "std": std_r, "model_path": model_path, "vec_path": vec_path}



In [16]:
# Train one SAC teacher per task; task i uses seed 100 + i.
# NOTE: a result is appended only after its task finishes, so if this cell is
# interrupted `results` holds fewer entries than `tasks`.
results = []

for i, task in enumerate(tasks):
    res = train_teacher_for_task(
        task=task,
        algo="SAC",
        total_timesteps=500000,
        seed=100 + i,
        normalize_obs=True,
        out_dir="./teachers",
        log_dir="./tb_logs",
    )
    results.append(res)

# To monitor training, run in a shell: tensorboard --logdir tb_logs


KeyboardInterrupt: 

In [None]:
# Display the per-task training summaries collected above.
results

##### Load Teacher for Memory Creation

In [None]:
import os
import numpy as np
import gymnasium as gym

from dataclasses import dataclass
from typing import Callable, Optional, Dict, Any, List

from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

def load_sac_teacher(task: Task, model_path: str, vec_path: Optional[str], seed: int = 0):
    """Load a saved SAC teacher together with an inference-ready environment.

    Args:
        task: Task used to rebuild the environment.
        model_path: Path of the saved SAC model.
        vec_path: Path of saved ``VecNormalize`` statistics, or ``None`` to
            use the environment without normalization.
        seed: Seed used when building the environment.

    Returns:
        ``(model, venv)``. When ``vec_path`` is given, ``venv`` is wrapped with
        the loaded statistics, with ``training`` and ``norm_reward`` off.
    """
    # Build the raw vec env first; normalization comes only from saved stats.
    env = build_vec_env(task, seed=seed, normalize_obs=False)

    has_saved_stats = vec_path is not None
    if has_saved_stats:
        # NOTE(review): the stats file is presumably a pickle (this notebook saves
        # it as .pkl), so only load files you trust; confirm.
        env = VecNormalize.load(vec_path, env)
        env.norm_reward = False
        env.training = False  # keep the loaded statistics fixed

    return SAC.load(model_path, env=env), env


In [None]:
import torch

@torch.no_grad()
def sac_policy_params(model: SAC, obs_batch: np.ndarray):
    """Query the SAC actor for its action-distribution parameters.

    Args:
        model: SAC model whose ``policy.actor`` is evaluated.
        obs_batch: Observations in VecEnv format, ``(n_envs, obs_dim)``
            (here n_envs=1).

    Returns:
        ``(mu, log_std)`` as NumPy arrays, each ``(n_envs, act_dim)``.
    """
    def _to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    obs_tensor = torch.as_tensor(obs_batch).to(model.device)
    actor = model.policy.actor

    # SB3 SAC actor helper; the third returned item is not needed here.
    mean_t, log_std_t, _unused = actor.get_action_dist_params(obs_tensor)

    return _to_numpy(mean_t), _to_numpy(log_std_t)


In [None]:
def collect_memory_from_sac_teacher(
    model: SAC,
    venv,
    task_name: str,
    n_steps: int = 100_000,
    deterministic_action: bool = True,
    store_actions: bool = True,
    seed: int = 0,
) -> Dict[str, Any]:
    """Roll out the teacher and record its policy outputs per visited observation.

    NOTE: observations coming from a VecNormalize-wrapped ``venv`` are already
    normalized, and they are stored exactly as the policy saw them.

    Args:
        model: SAC teacher used both to act and to read distribution parameters.
        venv: Vectorized environment to roll out in.
        task_name: Label stored under ``"task"`` in the result.
        n_steps: Number of ``venv.step`` calls to record (must be >= 1).
        deterministic_action: Passed to ``model.predict`` as ``deterministic``.
        store_actions: Whether to also record the executed actions.
        seed: Seed applied to ``venv`` before the first reset.

    Returns:
        Dict with ``"task"``, ``"obs"``, ``"mu"``, ``"log_std"`` and, when
        ``store_actions`` is true, ``"action"``. Array entries are the
        per-step batches concatenated along axis 0.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        # Fail early with a clear message instead of an empty-concatenate error.
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    venv.seed(seed)
    obs = venv.reset()

    obs_list = []
    mu_list = []
    logstd_list = []
    act_list = []

    for _ in range(n_steps):
        mu, log_std = sac_policy_params(model, obs)
        action, _ = model.predict(obs, deterministic=deterministic_action)

        obs_list.append(obs.copy())
        mu_list.append(mu.copy())
        logstd_list.append(log_std.copy())
        if store_actions:
            act_list.append(action.copy())

        # SB3 vectorized envs reset a finished sub-environment automatically and
        # return the first observation of the next episode, so no manual reset
        # is needed (an extra one would just discard that initial state).
        obs, _, _, _ = venv.step(action)

    data = {
        "task": task_name,
        "obs": np.concatenate(obs_list, axis=0),
        "mu": np.concatenate(mu_list, axis=0),
        "log_std": np.concatenate(logstd_list, axis=0),
    }
    if store_actions:
        data["action"] = np.concatenate(act_list, axis=0)

    return data


In [None]:
def save_memory_npz(data: Dict[str, Any], out_path: str):
    """Write a collected memory dataset to a compressed ``.npz`` file.

    Args:
        data: Mapping with ``"task"``, ``"obs"``, ``"mu"`` and ``"log_std"``;
            an ``"action"`` entry is saved too when present.
        out_path: Destination file. Missing parent directories are created.
    """
    parent_dir = os.path.dirname(out_path)
    if parent_dir:
        # A bare file name has an empty dirname, and os.makedirs("") raises.
        os.makedirs(parent_dir, exist_ok=True)

    arrays = {key: data[key] for key in ("task", "obs", "mu", "log_std")}
    if "action" in data:
        arrays["action"] = data["action"]

    np.savez_compressed(out_path, **arrays)
    print("Saved memory:", out_path)


In [None]:
MEM_DIR = "./memory_sac"
all_mem_paths = []

# Each task must be paired with its own training result: the env rebuilt for
# the teacher and the output file name both depend on the task.
if len(results) != len(tasks):
    print(f"Warning: {len(results)} results for {len(tasks)} tasks; only paired entries are processed.")

for task, teacher_result in zip(tasks, results):
    model, venv = load_sac_teacher(task, teacher_result["model_path"], teacher_result["vec_path"], seed=0)

    # try/finally so the env is closed even if collection or saving fails.
    try:
        mem = collect_memory_from_sac_teacher(
            model=model,
            venv=venv,
            task_name=task.name,
            n_steps=50_000,                # start small to validate
            deterministic_action=True,     # or False to cover more state space
            store_actions=True,
            seed=123
        )

        out_path = os.path.join(MEM_DIR, f"{task.name}_SAC_memory.npz")
        save_memory_npz(mem, out_path)
        all_mem_paths.append(out_path)
    finally:
        venv.close()

all_mem_paths
