In [2]:
!pip install -r requirements.txt

Collecting stable-baselines3[extra] (from -r requirements.txt (line 3))
  Downloading stable_baselines3-2.7.1-py3-none-any.whl.metadata (4.8 kB)
Collecting mujoco>=2.1.5 (from gymnasium[mujoco]->-r requirements.txt (line 2))
  Downloading mujoco-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
Collecting glfw (from mujoco>=2.1.5->gymnasium[mujoco]->-r requirements.txt (line 2))
  Downloading glfw-2.10.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_x86_64.whl.metadata (5.4 kB)
Downloading mujoco-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (7.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m88.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading stable_baselines3-2.7.1-py3-none-any.whl (188 kB)
[2K   [90m━━━━━━━━━━━━━━━━━

##### Import Libraries

In [3]:
import numpy as np
import gymnasium as gym
from gymnasium import Wrapper
from dataclasses import dataclass
from typing import Callable, Dict, Any, Tuple, List
import os
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3 import PPO, SAC, TD3
from stable_baselines3.common.evaluation import evaluate_policy

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


##### Task Variants for Cheetah (Target Velocity)

In [4]:
class HalfCheetahTargetVelocity(Wrapper):
    """Reward HalfCheetah for running at a fixed forward speed.

    The wrapped env's reward is discarded and replaced by
    ``-vel_scale * |vx - target_velocity| - ctrl_cost_weight * sum(action**2)``,
    where ``vx`` is read from ``env.unwrapped.data.qvel[0]`` (presumably the root's
    forward/x velocity — confirm against the model). Observations, termination and
    truncation flags are passed through unchanged.
    """

    def __init__(self, env, target_velocity: float, vel_scale: float = 1.0, ctrl_cost_weight: float = 0.1):
        super().__init__(env)
        self.vt = float(target_velocity)             # desired forward velocity
        self.vel_scale = float(vel_scale)            # weight of the velocity-tracking term
        self.ctrl_cost_weight = float(ctrl_cost_weight)

    def step(self, action):
        """Step the wrapped env and return the velocity-tracking reward plus vx/target in info."""
        obs, _base_reward, terminated, truncated, info = self.env.step(action)

        forward_speed = float(self.env.unwrapped.data.qvel[0])
        tracking_error = abs(forward_speed - self.vt)
        effort = float(np.sum(action ** 2))
        shaped_reward = -tracking_error * self.vel_scale - self.ctrl_cost_weight * effort

        extra = {"vx": forward_speed, "target_v": self.vt}
        return obs, shaped_reward, terminated, truncated, {**dict(info), **extra}


##### Task Variants for Ant (Target Direction)

In [5]:
class AntTargetDirection(Wrapper):
    """Reward the Ant for travelling along a fixed planar direction.

    The wrapped env's reward is replaced by the torso's planar velocity projected onto
    the unit vector ``direction``, minus ``ctrl_cost_weight * sum(action**2)``.
    The velocity is the finite difference of ``data.qpos[0:2]`` between consecutive
    calls, divided by ``env.unwrapped.dt``. It relies on ``reset()`` having been
    called through this wrapper so that a previous position exists.
    """

    def __init__(self, env, direction: np.ndarray, ctrl_cost_weight: float = 0.05):
        super().__init__(env)
        raw = np.asarray(direction, dtype=np.float32)
        magnitude = np.linalg.norm(raw) + 1e-8   # epsilon guards a zero-length direction
        self.dir = raw / magnitude
        self.ctrl_cost_weight = float(ctrl_cost_weight)
        self._prev_xy = None                     # torso (x, y) at the previous reset/step

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the starting torso position."""
        obs, info = self.env.reset(**kwargs)
        self._prev_xy = np.array(self.env.unwrapped.data.qpos[0:2], dtype=np.float32)
        return obs, info

    def step(self, action):
        """Step the wrapped env and return the directional-progress reward."""
        obs, _base_reward, terminated, truncated, info = self.env.step(action)

        sim = self.env.unwrapped
        current_xy = np.array(sim.data.qpos[0:2], dtype=np.float32)
        step_dt = max(float(sim.dt), 1e-8)
        planar_velocity = (current_xy - self._prev_xy) / step_dt
        self._prev_xy = current_xy

        progress = float(np.dot(planar_velocity, self.dir))
        effort_penalty = self.ctrl_cost_weight * float(np.sum(action ** 2))

        extra = {"vel_xy": planar_velocity, "dir": self.dir, "dir_speed": progress}
        return obs, progress - effort_penalty, terminated, truncated, {**dict(info), **extra}


##### Task Variants for Walker (Target Height)

In [6]:
class Walker2dTargetHeight(Wrapper):
    """Reward Walker2d for holding its body at a target height.

    The wrapped env's reward is replaced by
    ``-height_scale * |qpos[1] - target_height| - ctrl_cost_weight * sum(action**2)``.
    ``data.qpos[1]`` is treated as the torso height. For Walker2d this is presumably
    the root z coordinate — TODO confirm against the model XML. Observations and
    termination/truncation flags come straight from the wrapped env.
    """

    def __init__(self, env, target_height: float, height_scale: float = 1.0, ctrl_cost_weight: float = 0.001):
        super().__init__(env)
        self.ht = float(target_height)           # desired torso height
        self.height_scale = float(height_scale)  # weight of the height-tracking term
        self.ctrl_cost_weight = float(ctrl_cost_weight)

    def step(self, action):
        """Step the wrapped env and return the height-tracking reward plus height/target in info."""
        obs, _base_reward, terminated, truncated, info = self.env.step(action)

        torso_z = float(self.env.unwrapped.data.qpos[1])
        height_term = -abs(torso_z - self.ht) * self.height_scale
        effort_penalty = self.ctrl_cost_weight * float(np.sum(action ** 2))

        merged_info = dict(info)
        merged_info["height"] = torso_z
        merged_info["target_h"] = self.ht
        return obs, height_term - effort_penalty, terminated, truncated, merged_info

In [7]:
def _reset_with_seed(env, seed: int):
    """Reset ``env`` once with ``seed`` and return the same env object."""
    env.reset(seed=seed)
    return env


def task_base(env_id: str, seed: int = 0):
    """Plain ``gym.make(env_id)`` env, reset once with ``seed``."""
    return _reset_with_seed(gym.make(env_id), seed)


def task_halfcheetah_target_velocity(target_v: float, seed: int = 0):
    """HalfCheetah-v4 whose reward tracks the forward speed ``target_v``."""
    wrapped = HalfCheetahTargetVelocity(gym.make("HalfCheetah-v4"), target_velocity=target_v)
    return _reset_with_seed(wrapped, seed)


def task_ant_target_direction(dx: float, dy: float, seed: int = 0):
    """Ant-v4 rewarded for moving along the planar direction (dx, dy)."""
    heading = np.array([dx, dy], dtype=np.float32)
    return _reset_with_seed(AntTargetDirection(gym.make("Ant-v4"), direction=heading), seed)


def task_walker2d_target_height(target_h: float, seed: int = 0):
    """Walker2d-v4 rewarded for keeping its torso at height ``target_h``."""
    wrapped = Walker2dTargetHeight(gym.make("Walker2d-v4"), target_height=target_h)
    return _reset_with_seed(wrapped, seed)


In [8]:
@dataclass(frozen=True)
class Task:
    """A named training task; ``make_env()`` builds a fresh Gymnasium env for it."""

    name: str                        # used in log lines and in saved artifact file names
    make_env: Callable[[], gym.Env]  # zero-arg factory; each call builds a new env


# Active task set. Every entry ends with a comma so the commented-out variants below
# can be re-enabled without introducing a syntax error.
tasks: List[Task] = [
    Task("BASE HalfCheetah", lambda: task_base("HalfCheetah-v4", seed=0)),
    Task("BASE Hopper",      lambda: task_base("Hopper-v4", seed=0)),
    Task("BASE Walker2d",    lambda: task_base("Walker2d-v4", seed=0)),
    Task("BASE Ant",         lambda: task_base("Ant-v4", seed=0)),

    #Task("HC target_v=1.0",  lambda: task_halfcheetah_target_velocity(1.0, seed=1)),
    #Task("HC target_v=2.0",  lambda: task_halfcheetah_target_velocity(2.0, seed=2)),

    #Task("Ant EAST",         lambda: task_ant_target_direction(1.0, 0.0, seed=3)),
    #Task("Ant NORTH",        lambda: task_ant_target_direction(0.0, 1.0, seed=4)),

    #Task("Walker h=1.2",     lambda: task_walker2d_target_height(1.2, seed=5)),
    #Task("Walker h=1.6",     lambda: task_walker2d_target_height(1.6, seed=6)),
]


In [9]:
def build_vec_env(task: Task, seed: int = 0, normalize_obs: bool = True):
    """Wrap ``task`` in a single-env DummyVecEnv, optionally behind VecNormalize.

    The inner env is Monitor-wrapped (episode returns/lengths) and reset once with
    ``seed``. When ``normalize_obs`` is true, observations are normalized and clipped
    to +/-10; rewards are left unnormalized.
    """
    def _make_monitored_env():
        monitored = Monitor(task.make_env())  # logs episode returns/lengths
        monitored.reset(seed=seed)
        return monitored

    vec_env = DummyVecEnv([_make_monitored_env])
    if not normalize_obs:
        return vec_env
    return VecNormalize(vec_env, norm_obs=True, norm_reward=False, clip_obs=10.0)


In [10]:
def make_teacher(algo: str, env, seed: int = 0, logdir: str = None):
    """Construct an untrained SB3 "MlpPolicy" agent for ``env``.

    Args:
        algo: "SAC", "TD3" or "PPO" (case-insensitive).
        env: environment passed straight to the SB3 constructor.
        seed: RNG seed for the agent.
        logdir: TensorBoard log directory (None disables TensorBoard logging).

    Raises:
        ValueError: ``algo`` is not one of the supported names.
    """
    algo = algo.upper()
    shared = {"verbose": 1, "seed": seed, "tensorboard_log": logdir}

    # Builders are lambdas so only the selected algorithm class is ever touched.
    builders = {
        "SAC": lambda: SAC("MlpPolicy", env, batch_size=256, learning_rate=3e-4, gamma=0.99, **shared),
        "TD3": lambda: TD3("MlpPolicy", env, batch_size=256, learning_rate=1e-3, gamma=0.99, **shared),
        "PPO": lambda: PPO("MlpPolicy", env, n_steps=2048, batch_size=64, learning_rate=3e-4, gamma=0.99, **shared),
    }

    build = builders.get(algo)
    if build is None:
        raise ValueError(f"Unknown algo: {algo}")
    return build()


def train_teacher_for_task(
    task: Task,
    algo: str = "SAC",
    total_timesteps: int = 300_000,
    seed: int = 0,
    normalize_obs: bool = True,
    out_dir: str = "./teachers",
    log_dir: str = "./tb_logs",
):
    """Train, evaluate and save one teacher agent for ``task``.

    Args:
        task: task whose ``make_env`` builds the training env.
        algo: algorithm name forwarded to ``make_teacher``; also used verbatim in
            the saved file names.
        total_timesteps: environment steps passed to ``model.learn``.
        seed: seed for both the env and the agent.
        normalize_obs: wrap the env in VecNormalize (its stats are saved alongside).
        out_dir: directory for the model ``.zip`` and VecNormalize ``.pkl``.
        log_dir: TensorBoard log directory.

    Returns:
        dict with keys task, algo, mean, std, model_path, vec_path. ``vec_path`` is
        None when ``normalize_obs`` is false.
    """
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    venv = build_vec_env(task, seed=seed, normalize_obs=normalize_obs)
    try:
        model = make_teacher(algo, venv, seed=seed, logdir=log_dir)
        model.learn(total_timesteps=total_timesteps, progress_bar=True)

        # Freeze running obs statistics during evaluation so the saved normalizer
        # reflects training data only (only meaningful for VecNormalize).
        if isinstance(venv, VecNormalize):
            venv.training = False
            venv.norm_reward = False

        mean_r, std_r = evaluate_policy(model, venv, n_eval_episodes=10, deterministic=True)
        print(f"[{task.name}] {algo} eval: {mean_r:.2f} +/- {std_r:.2f}")

        model_path = os.path.join(out_dir, f"{task.name}_{algo}.zip")
        model.save(model_path)

        vec_path = None
        if isinstance(venv, VecNormalize):
            vec_path = os.path.join(out_dir, f"{task.name}_{algo}_vecnormalize.pkl")
            venv.save(vec_path)
    finally:
        # Release the simulator even if training, evaluation or saving fails.
        venv.close()

    return {"task": task.name, "algo": algo, "mean": mean_r, "std": std_r, "model_path": model_path, "vec_path": vec_path}



In [11]:
# Train one SAC teacher per task; task k gets seed 100 + k so every run is reproducible.
results = []

for task_idx, task in enumerate(tasks):
    teacher_seed = 100 + task_idx
    results.append(
        train_teacher_for_task(
            task=task,
            algo="SAC",
            total_timesteps=500_000,
            seed=teacher_seed,
            normalize_obs=True,
            out_dir="./teachers",
            log_dir="./tb_logs",
        )
    )

## tensorboard --logdir tb_logs


[BASE Ant] SAC eval: 3602.85 +/- 70.70


In [12]:
# Per-teacher summary: eval mean/std plus the saved model and VecNormalize paths.
results

[{'task': 'BASE HalfCheetah',
  'algo': 'SAC',
  'mean': np.float64(8743.4280451),
  'std': np.float64(122.44090484187974),
  'model_path': './teachers/BASE HalfCheetah_SAC.zip',
  'vec_path': './teachers/BASE HalfCheetah_SAC_vecnormalize.pkl'},
 {'task': 'BASE Hopper',
  'algo': 'SAC',
  'mean': np.float64(3534.9982952),
  'std': np.float64(74.30036539246545),
  'model_path': './teachers/BASE Hopper_SAC.zip',
  'vec_path': './teachers/BASE Hopper_SAC_vecnormalize.pkl'},
 {'task': 'BASE Walker2d',
  'algo': 'SAC',
  'mean': np.float64(4432.8367175),
  'std': np.float64(88.3051954926245),
  'model_path': './teachers/BASE Walker2d_SAC.zip',
  'vec_path': './teachers/BASE Walker2d_SAC_vecnormalize.pkl'},
 {'task': 'BASE Ant',
  'algo': 'SAC',
  'mean': np.float64(3602.8521288),
  'std': np.float64(70.70085649305115),
  'model_path': './teachers/BASE Ant_SAC.zip',
  'vec_path': './teachers/BASE Ant_SAC_vecnormalize.pkl'}]

##### Load Teacher for Memory Creation

In [None]:
import os
import numpy as np
import gymnasium as gym

from dataclasses import dataclass
from typing import Callable, Optional, Dict, Any, List

from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

def load_sac_teacher(task: Task, model_path: str, vec_path: Optional[str], seed: int = 0):
    """Restore a trained SAC teacher together with an inference-ready vec env.

    Args:
        task: task whose env is rebuilt (without fresh normalization).
        model_path: path to the saved SAC ``.zip``.
        vec_path: saved VecNormalize statistics, or None if the teacher was trained
            on raw observations.
        seed: seed for the rebuilt env.

    Returns:
        (model, venv). When ``vec_path`` is given, venv is a VecNormalize with
        frozen statistics (training=False) and unnormalized rewards.
    """
    env = build_vec_env(task, seed=seed, normalize_obs=False)

    if vec_path is None:
        return SAC.load(model_path, env=env), env

    normalized = VecNormalize.load(vec_path, env)
    # Inference only: keep the loaded running stats fixed and rewards unscaled.
    normalized.training = False
    normalized.norm_reward = False
    return SAC.load(model_path, env=normalized), normalized


In [None]:
import torch

@torch.no_grad()
def sac_policy_params(model: SAC, obs_batch: np.ndarray):
    """Query a SAC actor for its Gaussian parameters on a batch of observations.

    Args:
        model: SAC agent whose ``policy.actor`` is queried.
        obs_batch: VecEnv-format observations, shape (n_envs, obs_dim).

    Returns:
        (mu, log_std), each a NumPy array of shape (n_envs, act_dim). ``mu`` is the
        actor's mean, presumably before tanh squashing (the eval code applies tanh).
    """
    actor = model.policy.actor
    obs_tensor = torch.as_tensor(obs_batch).to(model.device)
    mean_actions, log_std, _extra = actor.get_action_dist_params(obs_tensor)
    return mean_actions.detach().cpu().numpy(), log_std.detach().cpu().numpy()


In [None]:
def collect_memory_from_sac_teacher(
    model: SAC,
    venv,
    task_name: str,
    n_steps: int = 100_000,
    deterministic_action: bool = True,
    store_actions: bool = True,
    seed: int = 0,
) -> Dict[str, Any]:
    """Roll out a SAC teacher and record a distillation memory.

    At every step it stores the observation, the actor's (mu, log_std) and,
    optionally, the action the teacher actually took. Observations coming from a
    VecNormalize-wrapped ``venv`` are already normalized.

    Args:
        model: trained SAC teacher.
        venv: single-env VecEnv to roll out in.
        task_name: label stored under ``"task"``.
        n_steps: number of transitions to record (must be >= 1).
        deterministic_action: passed to ``model.predict``; False covers more states.
        store_actions: also record the executed actions under ``"action"``.
        seed: seed applied to ``venv`` before the first reset.

    Returns:
        dict with "task", "obs" (n_steps, obs_dim), "mu" (n_steps, act_dim),
        "log_std" (n_steps, act_dim) and optionally "action".

    Raises:
        ValueError: ``n_steps`` < 1.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    venv.seed(seed)
    obs = venv.reset()

    obs_list, mu_list, logstd_list, act_list = [], [], [], []

    for _ in range(n_steps):
        mu, log_std = sac_policy_params(model, obs)
        action, _ = model.predict(obs, deterministic=deterministic_action)

        obs_list.append(obs.copy())
        mu_list.append(mu.copy())
        logstd_list.append(log_std.copy())
        if store_actions:
            act_list.append(action.copy())

        # SB3 VecEnvs reset finished sub-envs automatically: when an episode ends, the
        # returned obs is already the first observation of the next episode, so a manual
        # reset here would only throw that observation away.
        obs, _reward, _done, _info = venv.step(action)

    data = {
        "task": task_name,
        "obs": np.concatenate(obs_list, axis=0),        # (n_steps, obs_dim)
        "mu": np.concatenate(mu_list, axis=0),          # (n_steps, act_dim)
        "log_std": np.concatenate(logstd_list, axis=0), # (n_steps, act_dim)
    }
    if store_actions:
        data["action"] = np.concatenate(act_list, axis=0)

    return data


In [None]:
def save_memory_npz(data: Dict[str, Any], out_path: str):
    """Write a teacher memory dict to a compressed ``.npz`` file.

    Stores the keys task, obs, mu and log_std, plus action when present in ``data``.
    The parent directory is created if needed.

    Args:
        data: dict as returned by ``collect_memory_from_sac_teacher``.
        out_path: destination path. A bare filename writes into the current directory.
    """
    parent = os.path.dirname(out_path)
    if parent:  # a bare filename has no directory part, and os.makedirs("") would raise
        os.makedirs(parent, exist_ok=True)

    arrays = {
        "task": data["task"],
        "obs": data["obs"],
        "mu": data["mu"],
        "log_std": data["log_std"],
    }
    if "action" in data:
        arrays["action"] = data["action"]

    np.savez_compressed(out_path, **arrays)
    print("Saved memory:", out_path)


In [None]:
MEM_DIR = "./memory_sac"
all_mem_paths = []

# Pair each task with its teacher explicitly: results[i] only matches tasks[i] if the
# training cell ran over the same, unchanged task list.
if len(tasks) != len(results):
    raise ValueError(f"{len(tasks)} tasks but {len(results)} trained teachers; re-run training.")

for t, r in zip(tasks, results):
    if r["task"] != t.name:
        raise ValueError(f"Teacher result {r['task']!r} does not match task {t.name!r}")

    model, venv = load_sac_teacher(t, r["model_path"], r["vec_path"], seed=0)
    try:
        mem = collect_memory_from_sac_teacher(
            model=model,
            venv=venv,
            task_name=t.name,
            n_steps=50_000,                # start small to validate
            deterministic_action=True,     # or False to cover more state space
            store_actions=True,
            seed=123,
        )
    finally:
        venv.close()  # release the simulator even if collection fails

    out_path = os.path.join(MEM_DIR, f"{t.name}_SAC_memory.npz")
    save_memory_npz(mem, out_path)
    all_mem_paths.append(out_path)

all_mem_paths


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
class DistillMemoryDataset(Dataset):
    """Torch dataset over a teacher memory ``.npz`` (obs, mu, log_std[, action]).

    All arrays are cast to float32 and kept in memory. Items are
    ``(obs, mu_t, log_std_t)``, or ``(obs, mu_t, log_std_t, action_t)`` when the file
    contains an "action" array.
    """

    def __init__(self, npz_path: str):
        # The memory holds only numeric arrays, so pickle support is unnecessary. Leaving it
        # off means a crafted file cannot execute code on load. The ``with`` block closes
        # the underlying NpzFile once the arrays have been copied out.
        with np.load(npz_path) as d:
            self.obs = d["obs"].astype(np.float32)
            self.mu_t = d["mu"].astype(np.float32)
            self.log_std_t = d["log_std"].astype(np.float32)
            self.action_t = d["action"].astype(np.float32) if "action" in d.files else None

    def __len__(self):
        return self.obs.shape[0]

    def __getitem__(self, idx):
        obs = self.obs[idx]
        mu_t = self.mu_t[idx]
        log_std_t = self.log_std_t[idx]
        if self.action_t is None:
            return obs, mu_t, log_std_t
        return obs, mu_t, log_std_t, self.action_t[idx]


In [None]:
class GaussianStudentPolicy(nn.Module):
    """MLP student mapping observations to a diagonal Gaussian (mean, clamped log-std).

    Args:
        obs_dim: observation size.
        act_dim: action size.
        hidden: widths of the ReLU hidden layers (may be empty).
        log_std_bounds: (min, max) used to clamp the predicted log-std.
    """

    def __init__(self, obs_dim: int, act_dim: int, hidden=(256, 256), log_std_bounds=(-5.0, 2.0)):
        super().__init__()
        self.log_std_min, self.log_std_max = log_std_bounds

        widths = [obs_dim, *hidden]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        self.backbone = nn.Sequential(*blocks)

        trunk_dim = widths[-1]
        self.mu_head = nn.Linear(trunk_dim, act_dim)
        self.log_std_head = nn.Linear(trunk_dim, act_dim)

    def forward(self, obs: torch.Tensor):
        """Return ``(mu, log_std)`` for a batch of observations."""
        features = self.backbone(obs)
        log_std = self.log_std_head(features).clamp(self.log_std_min, self.log_std_max)
        return self.mu_head(features), log_std


##### Distillation Methods 1 and 2: soft labels (KL to the teacher's Gaussian) and hard labels (MSE to the teacher's actions)

In [None]:
# D1
def diag_gaussian_kl(mu_t, log_std_t, mu_s, log_std_s):
    """KL(teacher || student) for diagonal Gaussians, summed over action dims, averaged over the batch.

    Per dimension: log(std_s / std_t) + (std_t^2 + (mu_t - mu_s)^2) / (2 std_s^2) - 1/2.
    All inputs share the shape (B, act_dim).
    """
    var_teacher = torch.exp(log_std_t) ** 2
    var_student = torch.exp(log_std_s) ** 2
    mean_gap_sq = (mu_t - mu_s) ** 2

    per_dim = (log_std_s - log_std_t) + (var_teacher + mean_gap_sq) / (2.0 * var_student) - 0.5
    return per_dim.sum(dim=-1).mean()


In [None]:
# D2
def action_mse(mu_s, action_t, squash: bool = True):
    """Mean-squared error between the student's action and the teacher's recorded action.

    ``mu_s`` is the student's raw Gaussian mean, and the offline eval acts with
    ``tanh(mu)``. The recorded teacher actions come from ``model.predict`` and are
    presumably already tanh-squashed into [-1, 1] for these MuJoCo tasks. The student
    mean is therefore squashed before comparing. Regressing the raw mean onto squashed
    targets would make the evaluated action ``tanh(tanh(.))``, shrinking it toward zero.

    Args:
        mu_s: student means, shape (B, act_dim).
        action_t: teacher actions, same shape.
        squash: apply tanh to ``mu_s`` first (set False for unsquashed targets).
    """
    pred = torch.tanh(mu_s) if squash else mu_s
    return F.mse_loss(pred, action_t)

##### Distillation Method 3 uses certainty-weighted KL: states where the teacher is confident (low action std) are weighted more heavily

In [None]:
# D3
def certainty_weights(log_std_t, eps=1e-6):
    """Per-sample weights inversely proportional to the teacher's mean action std.

    Args:
        log_std_t: teacher log-stds, shape (B, act_dim).
        eps: keeps the reciprocal finite.

    Returns:
        (B,) weights rescaled to have batch mean of about 1, so the loss scale stays stable.
    """
    mean_std = torch.exp(log_std_t).mean(dim=-1)   # (B,)
    raw = 1.0 / (eps + mean_std)
    return raw / (raw.mean() + 1e-8)

def weighted_diag_gaussian_kl(mu_t, log_std_t, mu_s, log_std_s):
    """Certainty-weighted KL(teacher || student) for diagonal Gaussians.

    Each sample's KL (summed over action dims) is multiplied by
    ``certainty_weights(log_std_t)`` before averaging over the batch.
    """
    var_teacher = torch.exp(log_std_t) ** 2
    var_student = torch.exp(log_std_s) ** 2

    per_dim = (log_std_s - log_std_t) + (var_teacher + (mu_t - mu_s) ** 2) / (2.0 * var_student) - 0.5
    per_sample = per_dim.sum(dim=-1)              # (B,)

    return (certainty_weights(log_std_t) * per_sample).mean()


In [None]:
def train_offline_distill(
    npz_path: str,
    method: str,
    epochs: int = 10,
    batch_size: int = 256,
    lr: float = 3e-4,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Train a fresh GaussianStudentPolicy offline on a teacher memory file.

    Args:
        npz_path: memory file with obs / mu / log_std (and optionally action) arrays.
        method: case-insensitive loss selector.
            "D1_KL": KL to the teacher Gaussian.
            "D2_MSE": regress the teacher's stored actions; needs "action" in the file.
            "D3_WKL": certainty-weighted KL.
        epochs: passes over the memory.
        batch_size: minibatch size.
        lr: Adam learning rate.
        device: torch device for the student.

    Returns:
        The trained student module, on ``device``.

    Raises:
        ValueError: unknown ``method``, D2_MSE without stored actions, or an empty memory.
    """
    method = method.upper()
    if method not in ("D1_KL", "D2_MSE", "D3_WKL"):
        raise ValueError("Unknown method. Use: D1_KL, D2_MSE, D3_WKL")

    ds = DistillMemoryDataset(npz_path)
    if len(ds) == 0:
        raise ValueError(f"Memory file {npz_path!r} contains no samples.")
    if method == "D2_MSE" and ds.action_t is None:
        raise ValueError("D2_MSE needs 'action' stored in npz.")

    # Drop the ragged last batch only when at least one full batch exists. Otherwise a
    # memory smaller than batch_size would yield zero batches and the student would
    # silently stay untrained.
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=len(ds) >= batch_size)

    obs_dim = ds.obs.shape[1]
    act_dim = ds.mu_t.shape[1]
    student = GaussianStudentPolicy(obs_dim, act_dim).to(device)
    opt = torch.optim.Adam(student.parameters(), lr=lr)

    for ep in range(1, epochs + 1):
        losses = []
        for batch in dl:
            obs, mu_t, log_std_t = (x.to(device) for x in batch[:3])
            action_t = batch[3].to(device) if len(batch) == 4 else None

            mu_s, log_std_s = student(obs)

            if method == "D1_KL":
                loss = diag_gaussian_kl(mu_t, log_std_t, mu_s, log_std_s)
            elif method == "D2_MSE":
                loss = action_mse(mu_s, action_t)
            else:  # "D3_WKL"
                loss = weighted_diag_gaussian_kl(mu_t, log_std_t, mu_s, log_std_s)

            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())

        print(f"Epoch {ep:02d} | {method} loss: {np.mean(losses):.4f}")

    return student



In [None]:
def save_student(student, path: str):
    """Save ``student.state_dict()`` to ``path`` with ``torch.save``.

    The parent directory is created if needed. A bare filename writes into the
    current directory.
    """
    parent = os.path.dirname(path)
    if parent:  # a bare filename has no directory part, and os.makedirs("") would raise
        os.makedirs(parent, exist_ok=True)
    torch.save(student.state_dict(), path)
    print("Saved student:", path)


##### Distillation Method 4 (kickstarting): the student trains with regular RL (SAC), and after each training chunk its actor is pulled toward the teacher memory with KL distillation steps

In [None]:
def diag_gaussian_kl_torch(mu_t, log_std_t, mu_s, log_std_s):
    """Batch-mean KL(teacher || student) for diagonal Gaussians (summed over action dims)."""
    var_t = torch.exp(log_std_t) ** 2
    var_s = torch.exp(log_std_s) ** 2
    log_ratio = log_std_s - log_std_t
    quad = (var_t + (mu_t - mu_s) ** 2) / (2.0 * var_s)
    return (log_ratio + quad - 0.5).sum(dim=-1).mean()

def sac_actor_distill_step(student_sac, obs_np, mu_teacher_np, log_std_teacher_np):
    """One gradient step pulling the student SAC actor toward the teacher's Gaussian.

    Uses the actor's own optimizer, so only actor parameters are updated.

    Args:
        student_sac: SAC-like object exposing ``device`` and ``policy.actor`` (with
            ``get_action_dist_params`` and ``optimizer``).
        obs_np, mu_teacher_np, log_std_teacher_np: batch arrays from the teacher memory.

    Returns:
        The KL loss value before the update, as a Python float.
    """
    actor = student_sac.policy.actor
    obs, mu_t, ls_t = (
        torch.as_tensor(arr, dtype=torch.float32, device=student_sac.device)
        for arr in (obs_np, mu_teacher_np, log_std_teacher_np)
    )

    mu_s, ls_s, _ = actor.get_action_dist_params(obs)
    loss = diag_gaussian_kl_torch(mu_t, ls_t, mu_s, ls_s)

    actor.optimizer.zero_grad()
    loss.backward()
    actor.optimizer.step()
    return float(loss.item())


In [None]:
def train_kickstarting(
    student_sac,
    memory_npz_path: str,
    total_timesteps: int = 300_000,
    chunk: int = 20_000,
    distill_steps_per_chunk: int = 200,
    distill_batch_size: int = 256,
):
    """Alternate regular SAC training with KL distillation toward a teacher memory.

    After every ``chunk`` environment steps of ``student_sac.learn``, the actor takes
    ``distill_steps_per_chunk`` KL steps on random minibatches from the memory.

    Args:
        student_sac: SAC student, trained in place.
        memory_npz_path: teacher memory with obs / mu / log_std arrays.
        total_timesteps: total environment steps; the last chunk is shortened so
            this is never exceeded.
        chunk: environment steps per learn call (must be positive).
        distill_steps_per_chunk: distillation gradient steps after each chunk.
        distill_batch_size: memory samples per distillation step.

    Raises:
        ValueError: ``chunk`` is not positive (the loop would never terminate).
    """
    if chunk <= 0:
        raise ValueError(f"chunk must be positive, got {chunk}")

    # Copy the arrays out and close the NpzFile immediately.
    with np.load(memory_npz_path) as mem:
        obs_mem = mem["obs"].astype(np.float32)
        mu_mem = mem["mu"].astype(np.float32)
        ls_mem = mem["log_std"].astype(np.float32)
    n = obs_mem.shape[0]

    trained = 0
    while trained < total_timesteps:
        # Never overshoot total_timesteps when it is not a multiple of chunk.
        this_chunk = min(chunk, total_timesteps - trained)
        student_sac.learn(total_timesteps=this_chunk, reset_num_timesteps=False, progress_bar=True)
        trained += this_chunk

        losses = []
        for _ in range(distill_steps_per_chunk):
            idx = np.random.randint(0, n, size=(distill_batch_size,))
            losses.append(sac_actor_distill_step(student_sac, obs_mem[idx], mu_mem[idx], ls_mem[idx]))

        if losses:
            print(f"After {trained} steps | distill loss mean: {np.mean(losses):.4f}")
        else:
            print(f"After {trained} steps | no distillation steps run")


##### Eval in a Normalized Env

In [None]:
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.monitor import Monitor
import torch
import numpy as np

def make_base_vec_env(env_id: str, seed: int = 0):
    """Single-env DummyVecEnv around a Monitor-wrapped ``gym.make(env_id)``, reset once with ``seed``."""
    def _factory():
        monitored = Monitor(gym.make(env_id))
        monitored.reset(seed=seed)
        return monitored

    return DummyVecEnv([_factory])

def load_eval_env_with_vecnorm(env_id: str, vec_path: str, seed: int = 0):
    """Base env for ``env_id`` behind VecNormalize stats loaded from ``vec_path``.

    The statistics are frozen (training=False) and rewards are returned unnormalized,
    so the returned env is suitable for evaluation with a teacher's normalization.
    """
    normalized = VecNormalize.load(vec_path, make_base_vec_env(env_id, seed=seed))
    normalized.training = False
    normalized.norm_reward = False
    return normalized

@torch.no_grad()
def eval_offline_student(student, venv, n_episodes=10, device=None):
    """Roll out a distilled student deterministically and summarize episode returns.

    Each step the student's mean is squashed with tanh and sent to ``venv`` as the
    action (the log-std is ignored). The student is switched to eval mode.

    Args:
        student: module returning ``(mu, log_std)`` for a batch of observations.
        venv: single-env VecEnv (observations presumably already normalized).
        n_episodes: number of episodes to run.
        device: torch device for inputs; defaults to the student's parameter device.

    Returns:
        (mean_return, std_return) as Python floats.
    """
    device = next(student.parameters()).device if device is None else device
    student.eval()

    episode_returns = []
    for _episode in range(n_episodes):
        obs = venv.reset()  # shape (1, obs_dim)
        total, finished = 0.0, False
        while not finished:
            mean_action, _log_std = student(torch.as_tensor(obs, dtype=torch.float32, device=device))
            # Actions must lie in [-1, 1]; squash like SAC does.
            squashed = torch.tanh(mean_action).cpu().numpy()
            obs, reward, done, _info = venv.step(squashed)
            total += float(reward[0])
            finished = bool(done[0])
        episode_returns.append(total)

    return float(np.mean(episode_returns)), float(np.std(episode_returns))


##### Training Run for Distillation

In [None]:
# Artifacts are named after the task name, which contains a space ("BASE HalfCheetah"):
# the teacher cell writes "<out_dir>/<task name>_SAC_vecnormalize.pkl" and the memory cell
# writes "<MEM_DIR>/<task name>_SAC_memory.npz". Build the paths from that same scheme;
# an underscored "BASE_HalfCheetah" file does not exist.
TASK_NAME = "BASE HalfCheetah"
env_id = "HalfCheetah-v4"
npz_path = os.path.join("./memory_sac", f"{TASK_NAME}_SAC_memory.npz")
vec_path = os.path.join("./teachers", f"{TASK_NAME}_SAC_vecnormalize.pkl")

student_d1 = train_offline_distill(npz_path, "D1_KL", epochs=10)
student_d2 = train_offline_distill(npz_path, "D2_MSE", epochs=10)
student_d3 = train_offline_distill(npz_path, "D3_WKL", epochs=10)

venv_eval = load_eval_env_with_vecnorm(env_id, vec_path, seed=0)
try:
    print("D1:", eval_offline_student(student_d1, venv_eval))
    print("D2:", eval_offline_student(student_d2, venv_eval))
    print("D3:", eval_offline_student(student_d3, venv_eval))
finally:
    venv_eval.close()  # release the simulator even if an evaluation fails


In [None]:
from stable_baselines3 import SAC

# Use the teacher's frozen VecNormalize stats so the student sees observations on the same
# scale as the memory it is distilled from.
venv_student = load_eval_env_with_vecnorm(env_id, vec_path, seed=0)

student_sac = SAC("MlpPolicy", venv_student, verbose=1, seed=0, batch_size=256, learning_rate=3e-4)
try:
    train_kickstarting(
        student_sac,
        memory_npz_path=npz_path,
        total_timesteps=300_000,
        chunk=20_000,
        distill_steps_per_chunk=200,
        distill_batch_size=256,
    )
    # evaluate_policy returns a (mean_reward, std_reward) tuple.
    rewards = evaluate_policy(student_sac, venv_student, n_eval_episodes=10, deterministic=True)
finally:
    venv_student.close()  # release the simulator even if training/evaluation fails

mean_reward, std_reward = rewards
print(f"Kickstarted SAC student: {mean_reward:.2f} +/- {std_reward:.2f}")