##### Import Libraries

In [1]:
from typing import Callable, Dict, Any, Tuple, List
import os
import numpy as np
from stable_baselines3.common.evaluation import evaluate_policy
from src.teacher import Task
from src.teacher import train_teacher_for_task, task_base, task_halfcheetah_target_velocity
from src.teacher import task_walker2d_target_velocity
from src.memory import load_sac_teacher, collect_memory_from_sac_teacher, save_memory_npz

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
# Task registries used throughout the notebook. Every Task is constructed from
# a display name and a zero-argument lambda; the lambda presumably builds the
# (seeded) environment on demand -- confirm against src.teacher.Task.

# Unshaped baseline environments.
# NOTE(review): the hard-coded `results` list further down contains four BASE
# teachers (HalfCheetah, Hopper, Walker2d, Ant) while this list only has two;
# verify that the two collections are meant to differ.
tasks = [
    Task("BASE HalfCheetah", lambda: task_base("HalfCheetah-v4", seed=0)),
    Task("BASE Walker2d",    lambda: task_base("Walker2d-v4", seed=0)),
]

# HalfCheetah variants parameterised by a target velocity (the first positional
# argument; a negative value is used for the "BACK" task). All use seed=0.
halfcheetah_tasks = [
    Task("HC_WALK_v1.0",  lambda: task_halfcheetah_target_velocity( 1.0, seed=0)),
    Task("HC_RUN_v6.0",   lambda: task_halfcheetah_target_velocity( 6.0, seed=0)),
    Task("HC_BACK_v-1.0", lambda: task_halfcheetah_target_velocity(-1.0, seed=0)),
]

# Walker2d variants parameterised by a target velocity. All use seed=1
# (note: different from the HalfCheetah seed above).
walker_tasks = [
    Task("W_WALK_v1.0",  lambda: task_walker2d_target_velocity( 1.0, seed=1)),
    Task("W_RUN_v3.5",   lambda: task_walker2d_target_velocity( 3.5, seed=1)),
    Task("W_BACK_v-1.0", lambda: task_walker2d_target_velocity(-1.0, seed=1)),
    # Optional (add later):
    # Task("W_JUMP", lambda: task_walker2d_jump(seed=1, baseline_height=1.25, beta=5.0)),
]



In [4]:
# Settings shared by every HalfCheetah teacher run in this cell.
_teacher_settings = dict(
    algo="SAC",
    total_timesteps=1_000_000,  # recommended for shaped tasks; 300k may be low
    seed=0,
    normalize_obs=True,
    out_dir="./teachers",
    log_dir="./tb_logs",
)

# Train one teacher per HalfCheetah task. Results are appended one at a time so
# that an interrupted run still leaves the finished entries in `results`.
results = []
for t in halfcheetah_tasks:
    res = train_teacher_for_task(task=t, **_teacher_settings)
    results.append(res)

results

  logger.deprecation(


Using cpu device
Logging to ./tb_logs\SAC_3


KeyboardInterrupt: 

In [None]:
# Hard-coded evaluation records of previously trained BASE teachers, so later
# cells can run without re-training. Each record holds the task name, the
# algorithm, mean/std of the evaluation return and the artefact paths, which
# follow the pattern ./teachers/<task>_SAC{.zip,_vecnormalize.pkl}.
results = [
    {
        "task": task_name,
        "algo": "SAC",
        "mean": np.float64(mean_return),
        "std": np.float64(std_return),
        "model_path": f"./teachers/{task_name}_SAC.zip",
        "vec_path": f"./teachers/{task_name}_SAC_vecnormalize.pkl",
    }
    for task_name, mean_return, std_return in (
        ("BASE HalfCheetah", 8743.4280451, 122.44090484187974),
        ("BASE Hopper", 3534.9982952, 74.30036539246545),
        ("BASE Walker2d", 4432.8367175, 88.3051954926245),
        ("BASE Ant", 3602.8521288, 70.70085649305115),
    )
]


##### Load Teacher for Memory Creation

In [None]:
MEM_DIR = "./memory_sac"
# Create the output folder up front so a missing directory cannot fail the
# save step after a rollout has already been collected.
os.makedirs(MEM_DIR, exist_ok=True)

# Look teachers up by task name instead of by list position: `tasks` and
# `results` are built independently (different lengths / ordering), so
# `results[i]` can silently pair a task with another environment's teacher.
results_by_task = {r["task"]: r for r in results}

all_mem_paths = []

for t in tasks:
    r = results_by_task.get(t.name)
    if r is None:
        raise KeyError(
            f"No teacher result recorded for task {t.name!r}; "
            f"available: {sorted(results_by_task)}"
        )

    model, venv = load_sac_teacher(t, r["model_path"], r["vec_path"], seed=0)
    try:
        mem = collect_memory_from_sac_teacher(
            model=model,
            venv=venv,
            task_name=t.name,
            n_steps=50_000,                # start small to validate
            deterministic_action=True,     # or False to cover more state space
            store_actions=True,
            seed=123,
        )

        out_path = os.path.join(MEM_DIR, f"{t.name}_SAC_memory.npz")
        save_memory_npz(mem, out_path)
        all_mem_paths.append(out_path)
    finally:
        # Release the environment even if collection or saving fails.
        venv.close()

all_mem_paths


In [None]:
class DistillMemoryDataset(Dataset):
    """In-memory dataset over a teacher memory ``.npz`` file.

    The archive must contain the arrays ``obs``, ``mu`` and ``log_std``; an
    ``action`` array is optional. All arrays are converted to float32 and are
    indexed along their first axis.

    Each item is ``(obs, mu, log_std)`` or, when ``action`` is present,
    ``(obs, mu, log_std, action)``.

    Raises:
        ValueError: if the stored arrays do not share the same number of rows.
    """

    def __init__(self, npz_path: str):
        # NOTE(security): allow_pickle=True is kept for backward compatibility
        # with existing memory files, but it lets NumPy unpickle object arrays.
        # Only load archives from a trusted source.
        # The context manager closes the underlying zip file handle, which
        # np.load otherwise keeps open for the lifetime of the NpzFile.
        with np.load(npz_path, allow_pickle=True) as d:
            self.obs = d["obs"].astype(np.float32)
            self.mu_t = d["mu"].astype(np.float32)
            self.log_std_t = d["log_std"].astype(np.float32)
            self.action_t = d["action"].astype(np.float32) if "action" in d.files else None

        # Fail early on inconsistent archives instead of raising an IndexError
        # (or silently truncating) somewhere in the middle of training.
        n_rows = self.obs.shape[0]
        named_arrays = (
            ("mu", self.mu_t),
            ("log_std", self.log_std_t),
            ("action", self.action_t),
        )
        for key, arr in named_arrays:
            if arr is not None and arr.shape[0] != n_rows:
                raise ValueError(
                    f"'{key}' has {arr.shape[0]} rows but 'obs' has {n_rows} in {npz_path}"
                )

    def __len__(self):
        return self.obs.shape[0]

    def __getitem__(self, idx):
        obs = self.obs[idx]
        mu_t = self.mu_t[idx]
        log_std_t = self.log_std_t[idx]
        if self.action_t is None:
            return obs, mu_t, log_std_t
        return obs, mu_t, log_std_t, self.action_t[idx]


In [None]:

class GaussianStudentPolicy(nn.Module):
    """MLP student that outputs the mean and clamped log-std of a diagonal Gaussian.

    Args:
        obs_dim: size of the observation vector.
        act_dim: size of the action vector.
        hidden: widths of the hidden Linear+ReLU layers of the shared backbone.
        log_std_bounds: ``(min, max)`` range the predicted log-std is clamped to.
    """

    def __init__(self, obs_dim, act_dim, hidden=(256, 256), log_std_bounds=(-5.0, 2.0)):
        super().__init__()
        self.log_std_min, self.log_std_max = log_std_bounds

        # Consecutive width pairs define the Linear layers of the backbone.
        widths = [obs_dim, *hidden]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Linear(fan_in, fan_out))
            blocks.append(nn.ReLU())
        self.backbone = nn.Sequential(*blocks)

        # Both heads read the last backbone width (obs_dim when `hidden` is empty).
        feature_dim = widths[-1]
        self.mu_head = nn.Linear(feature_dim, act_dim)
        self.log_std_head = nn.Linear(feature_dim, act_dim)

    def forward(self, obs, return_features=False):
        """Return ``(mu, log_std)``, plus the backbone features if requested."""
        features = self.backbone(obs)  # student latent
        mean = self.mu_head(features)
        raw_log_std = self.log_std_head(features)
        log_std = torch.clamp(raw_log_std, self.log_std_min, self.log_std_max)
        if not return_features:
            return mean, log_std
        return mean, log_std, features


##### Distillation Methods 1 and 2: D1 matches the teacher's action distribution (soft labels, KL), D2 regresses onto the teacher's actions (hard labels, MSE)

In [None]:
# D1
def diag_gaussian_kl(mu_t, log_std_t, mu_s, log_std_s):
    """KL(teacher || student) between diagonal Gaussians.

    All inputs are expected with shape (B, act_dim). The per-dimension KL is
    summed over the last axis and then averaged over the batch, so a scalar
    tensor is returned.
    """
    teacher_var = torch.exp(log_std_t).pow(2)
    student_var = torch.exp(log_std_s).pow(2)

    log_ratio = log_std_s - log_std_t                  # log(sigma_s / sigma_t)
    spread = teacher_var + (mu_t - mu_s).pow(2)        # sigma_t^2 + (mu_t - mu_s)^2

    per_dim = log_ratio + spread / (2.0 * student_var) - 0.5
    return per_dim.sum(dim=-1).mean()


In [None]:
# D2
def action_mse(mu_s, action_t):
    """Mean squared error between the student mean and the stored teacher action.

    NOTE(review): evaluation applies tanh to the student mean before stepping
    the environment; if the stored teacher actions are already squashed, this
    loss compares an unsquashed mean with a squashed target -- confirm against
    how the memory is collected.
    """
    return F.mse_loss(input=mu_s, target=action_t, reduction="mean")

##### Distillation Method 3 uses certainty weighting: states where the teacher is confident (low action std) receive a larger weight in the loss

In [None]:
# D3
def certainty_weights(log_std_t, eps=1e-6):
    """Per-sample weights that grow as the teacher's mean std shrinks.

    ``log_std_t`` has shape (B, act_dim); the result has shape (B,). The raw
    weight is ``1 / (eps + mean_std)`` and is then divided by the batch mean of
    the raw weights so the overall loss scale stays stable.
    """
    mean_std = log_std_t.exp().mean(dim=-1)     # (B,)
    raw = 1.0 / (eps + mean_std)                # (B,)
    return raw / (raw.mean() + 1e-8)

def weighted_diag_gaussian_kl(mu_t, log_std_t, mu_s, log_std_s):
    """Certainty-weighted KL(teacher || student) for diagonal Gaussians.

    The per-sample KL (summed over action dims) is multiplied by the weights
    from ``certainty_weights(log_std_t)`` and averaged over the batch.
    """
    teacher_var = torch.exp(log_std_t).pow(2)
    student_var = torch.exp(log_std_s).pow(2)
    mean_gap_sq = (mu_t - mu_s).pow(2)

    per_dim = (log_std_s - log_std_t) + (teacher_var + mean_gap_sq) / (2.0 * student_var) - 0.5
    per_sample = per_dim.sum(dim=-1)                   # (B,)

    sample_weights = certainty_weights(log_std_t)      # (B,)
    return (sample_weights * per_sample).mean()


##### Distillation Method 4 uses the internal representations of the teacher for the student

In [None]:
import torch

@torch.no_grad()
def sac_teacher_latent(model, obs_batch_np):
    """
    Compute the teacher actor's latent features for a batch of observations.

    obs_batch_np: (B, obs_dim) or (1, obs_dim); the caller is expected to pass
    observations that are already normalized. The batch is converted to a
    float32 tensor on ``model.device`` and pushed through the actor's feature
    pathway followed by ``latent_pi``.

    Raises AttributeError when the actor exposes no usable pathway.
    """
    actor = model.policy.actor
    obs = torch.as_tensor(obs_batch_np, dtype=torch.float32, device=model.device)

    if hasattr(actor, "latent_pi"):
        # Prefer the features_extractor module; fall back to extract_features().
        for pathway in ("features_extractor", "extract_features"):
            if hasattr(actor, pathway):
                features = getattr(actor, pathway)(obs)
                return actor.latent_pi(features)

    raise AttributeError("Could not locate actor latent pathway. Inspect model.policy.actor to adapt extractor.")

class Projector(nn.Module):
    """Single linear layer mapping ``in_dim`` features to ``out_dim`` features.

    Used to bring student latents into the teacher's latent dimensionality.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x):
        projected = self.proj(x)
        return projected


def latent_cosine_loss(z_t, z_s_proj):
    """Return ``1 - mean cosine similarity`` between projected student and teacher latents.

    Both inputs are (B, D); the similarity is taken along the last axis.
    """
    similarity = F.cosine_similarity(z_s_proj, z_t, dim=-1)  # (B,)
    return 1.0 - similarity.mean()




In [None]:

def snapshot_params(model: torch.nn.Module):
    """Return detached clones of every parameter with ``requires_grad=True``.

    The list follows the order of ``model.parameters()`` and serves as the
    anchor for the L2 anchoring penalty.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return [param.detach().clone() for param in trainable]

def anchor_loss(model: torch.nn.Module, anchor_params, coeff: float):
    """Scaled squared L2 distance between trainable parameters and their anchors.

    ``anchor_params`` must list one tensor per trainable parameter, in the
    order of ``model.parameters()``. Returns the plain float ``0.0`` when the
    penalty is disabled (``coeff <= 0`` or no anchors).
    """
    if coeff <= 0 or anchor_params is None:
        return 0.0

    trainable = (param for param in model.parameters() if param.requires_grad)
    penalty = 0.0
    for idx, param in enumerate(trainable):
        penalty = penalty + torch.sum((param - anchor_params[idx]) ** 2)
    return coeff * penalty


In [None]:
def train_distill_step_no_replay(
    student,                       # <-- external student (keeps weights)
    method: str,
    current_npz: str,
    teacher_sac_model=None,        # only needed for D4_KL_LATENT
    projector=None,                # optional, for D4
    epochs: int = 10,
    batch_size: int = 256,
    lr: float = 3e-4,
    lambda_feat: float = 0.05,     # D4 only
    anchor_coeff: float = 1e-6,    # keep small; tunes stability vs plasticity
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Sequential offline distillation on *current task only* (no replay),
    with optional weight anchoring to reduce overwriting.

    Args:
      student: policy module; called as ``student(obs)`` -> ``(mu, log_std)``
        and, for D4, ``student(obs, return_features=True)`` -> ``(mu, log_std, z)``.
      method: one of D1_KL, D2_MSE, D3_WKL, D4_KL_LATENT (case-insensitive).
      current_npz: path of the teacher memory file for the current task.
      teacher_sac_model: teacher used to compute latents on the fly (D4 only).
      projector: existing student->teacher latent projector to keep training
        (D4 only); a new one is created when None.
      epochs, batch_size, lr: optimisation settings (Adam).
      lambda_feat: weight of the latent cosine loss (D4 only).
      anchor_coeff: weight of the L2 penalty towards the parameters the student
        had when this call started; values <= 0 disable anchoring.
      device: device the student, projector and batches are moved to.

    Returns:
      student (same object, updated), projector (for D4 if used)

    Raises:
      ValueError: unknown method, D4 without a teacher model, or D2 on a
        memory file that has no stored actions.
    """
    method = method.upper()

    # Validate the configuration before loading data or touching the model, so
    # a typo does not surface only after the first forward pass.
    if method not in ("D1_KL", "D2_MSE", "D3_WKL", "D4_KL_LATENT"):
        raise ValueError("Unknown method. Use: D1_KL, D2_MSE, D3_WKL, D4_KL_LATENT")
    use_latent = method == "D4_KL_LATENT"
    if use_latent and teacher_sac_model is None:
        raise ValueError("D4_KL_LATENT requires teacher_sac_model.")

    # --- data (current task only) ---
    ds = DistillMemoryDataset(current_npz)
    if method == "D2_MSE" and ds.action_t is None:
        raise ValueError("D2_MSE needs 'action' stored in npz.")

    # Only drop the ragged tail when at least one full batch exists; otherwise
    # drop_last would yield zero batches and training would silently do nothing
    # (and report a NaN mean loss).
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=len(ds) >= batch_size,
    )

    student = student.to(device)
    # The evaluation helper switches the student to eval mode; make sure every
    # training stage actually runs in train mode.
    student.train()

    # Anchor snapshot BEFORE learning this task
    anchor_params = snapshot_params(student) if anchor_coeff > 0 else None

    # --- projector setup for D4 ---
    if use_latent:
        if projector is None:
            # infer latent dims from a small sample of observations
            sample_obs = ds.obs[:8]
            with torch.no_grad():
                obs_t = torch.as_tensor(sample_obs, dtype=torch.float32, device=device)
                _, _, z_s = student(obs_t, return_features=True)
                student_lat_dim = int(z_s.shape[-1])

                z_t = sac_teacher_latent(teacher_sac_model, sample_obs)
                teacher_lat_dim = int(z_t.shape[-1])

            projector = Projector(student_lat_dim, teacher_lat_dim)

        projector = projector.to(device)
        projector.train()
        trainable = list(student.parameters()) + list(projector.parameters())
    else:
        trainable = student.parameters()

    opt = torch.optim.Adam(trainable, lr=lr)

    # --- training loop ---
    for ep in range(1, epochs + 1):
        losses = []
        for batch in dl:
            opt.zero_grad()

            # The dataset yields 3 fields without stored actions, 4 with them.
            if len(batch) == 3:
                obs, mu_t, log_std_t = batch
                action_t = None
            else:
                obs, mu_t, log_std_t, action_t = batch

            obs = obs.to(device)
            mu_t = mu_t.to(device)
            log_std_t = log_std_t.to(device)
            if action_t is not None:
                action_t = action_t.to(device)

            # Forward student
            if use_latent:
                mu_s, log_std_s, z_s = student(obs, return_features=True)
            else:
                mu_s, log_std_s = student(obs)

            # Base loss by method
            if method == "D1_KL":
                loss = diag_gaussian_kl(mu_t, log_std_t, mu_s, log_std_s)

            elif method == "D2_MSE":
                loss = action_mse(mu_s, action_t)

            elif method == "D3_WKL":
                loss = weighted_diag_gaussian_kl(mu_t, log_std_t, mu_s, log_std_s)

            else:  # D4_KL_LATENT
                loss_policy = diag_gaussian_kl(mu_t, log_std_t, mu_s, log_std_s)

                # teacher latent on-the-fly (obs are normalized already)
                z_t = sac_teacher_latent(teacher_sac_model, obs.detach().cpu().numpy()).to(device)

                z_s_proj = projector(z_s)
                loss_lat = latent_cosine_loss(z_t, z_s_proj)

                loss = loss_policy + lambda_feat * loss_lat

            # Add anchor penalty (prevents large drift on new task)
            if anchor_coeff > 0:
                loss = loss + anchor_loss(student, anchor_params, anchor_coeff)

            loss.backward()
            opt.step()
            losses.append(float(loss.item()))

        # print occasionally
        if ep == 1 or ep % 10 == 0:
            extra = f" (lambda_feat={lambda_feat})" if use_latent else ""
            print(f"Epoch {ep:02d} | {method} loss: {np.mean(losses):.4f} | anchor={anchor_coeff}{extra}")

    return student, projector


In [None]:
def save_student(student, path: str):
    """Save ``student.state_dict()`` to ``path``, creating parent folders if needed.

    A bare file name (no directory component) is written to the current
    working directory; ``os.makedirs("")`` would otherwise raise
    FileNotFoundError.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(student.state_dict(), path)
    print("Saved student:", path)


##### Eval in a Normalized Env

In [None]:
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.monitor import Monitor
import torch
import numpy as np

def make_base_vec_env(env_id: str, seed: int = 0):
    """Build a single-environment DummyVecEnv for ``env_id``.

    The environment is wrapped in Monitor and reset once with ``seed`` inside
    the factory, before DummyVecEnv takes ownership of it.
    """
    def _make_env():
        monitored = Monitor(gym.make(env_id))
        monitored.reset(seed=seed)
        return monitored

    return DummyVecEnv([_make_env])

def load_eval_env_with_vecnorm(env_id: str, vec_path: str, seed: int = 0):
    """Create an evaluation env wrapped with the VecNormalize state saved at ``vec_path``.

    The wrapper is switched to ``training = False`` and ``norm_reward = False``;
    this is intended to keep the loaded statistics fixed and to report
    unnormalized rewards during evaluation (see the SB3 VecNormalize docs).
    """
    eval_env = VecNormalize.load(vec_path, make_base_vec_env(env_id, seed=seed))
    eval_env.training = False
    eval_env.norm_reward = False
    return eval_env

@torch.no_grad()
def eval_offline_student(student, venv, n_episodes=10, device=None):
    """Roll out the student's deterministic policy and return (mean, std) of episode returns.

    Args:
        student: policy module; ``student(obs)`` must return ``(mu, log_std)``.
        venv: vectorized env with a single sub-environment, exposing
            ``reset()`` -> obs and ``step(action)`` -> (obs, reward, done, info).
        n_episodes: number of episodes to run.
        device: device for the observation tensors; defaults to the device of
            the student's first parameter.

    The student is put in eval mode for the rollouts and its previous
    train/eval mode is restored afterwards, so evaluating between training
    stages does not leave the model stuck in eval mode.
    """
    if device is None:
        device = next(student.parameters()).device

    was_training = student.training
    student.eval()

    rets = []
    try:
        for _ in range(n_episodes):
            obs = venv.reset()        # normalized obs (shape (1, obs_dim))
            done = [False]
            ep_ret = 0.0

            while not done[0]:
                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
                mu, _ = student(obs_t)

                # MuJoCo expects actions in [-1, 1]; match SAC-style squashing:
                action = torch.tanh(mu).cpu().numpy()

                obs, reward, done, _ = venv.step(action)
                ep_ret += float(reward[0])

            rets.append(ep_ret)
    finally:
        # Restore the caller's mode even if the environment raises.
        student.train(was_training)

    return float(np.mean(rets)), float(np.std(rets))


##### Training Run for Distillation

In [None]:
# Ordered curriculum for sequential distillation. Every entry points at the
# Gymnasium env id, the teacher memory file and the teacher artefacts of one
# BASE task; all paths follow the same "BASE <robot>" naming pattern.
TASK_SEQUENCE = [
    {
        "name": f"BASE {robot}",
        "env_id": f"{robot}-v4",
        "npz_path": f"./memory_sac/BASE {robot}_SAC_memory.npz",
        "model_path": f"./teachers/BASE {robot}_SAC.zip",
        "vec_path": f"./teachers/BASE {robot}_SAC_vecnormalize.pkl",
    }
    for robot in ("HalfCheetah", "Hopper", "Walker2d", "Ant")
]

# infer obs_dim / act_dim from first task memory
# (context manager closes the npz file handle once the shapes are read)
with np.load(TASK_SEQUENCE[0]["npz_path"], allow_pickle=True) as tmp:
    obs_dim = tmp["obs"].shape[1]
    act_dim = tmp["mu"].shape[1]

# The student has fixed input/output sizes, so every task in the sequence must
# share the first task's dimensions. Check this up front: otherwise the run
# would crash with a shape error only after earlier tasks finished training.
for cfg in TASK_SEQUENCE:
    with np.load(cfg["npz_path"], allow_pickle=True) as mem:
        task_dims = (mem["obs"].shape[1], mem["mu"].shape[1])
    if task_dims != (obs_dim, act_dim):
        raise ValueError(
            f"Task {cfg['name']!r} has (obs_dim, act_dim)={task_dims}, but the student "
            f"is built for {(obs_dim, act_dim)}; a single fixed-size student cannot "
            "consume this task sequence."
        )

student = GaussianStudentPolicy(obs_dim, act_dim)
projector = None  # used only for D4
methods = ["D1_KL", "D2_MSE", "D3_WKL", "D4_KL_LATENT"]
method = methods[0]

for i, cfg in enumerate(TASK_SEQUENCE):
    print(f"\n==============================")
    print(f" Training task {i+1}/{len(TASK_SEQUENCE)}: {cfg['name']}")
    print(f"==============================")

    # build Task object (same style you already use);
    # env_id is bound as a default argument to avoid late binding in the loop
    task = Task(
        cfg["name"],
        lambda env_id=cfg["env_id"]: task_base(env_id, seed=0)
    )

    if method == "D4_KL_LATENT":
        # load teacher (only D4 needs the live model for latent extraction)
        teacher_model, teacher_venv = load_sac_teacher(
            task,
            cfg["model_path"],
            cfg["vec_path"],
            seed=0
        )
        try:
            # ---- D4: KL + latent alignment ----
            # NOTE(review): projector=None builds a fresh projector for every
            # task, so the alignment learned on earlier tasks is discarded;
            # pass `projector` instead if it should persist across tasks.
            student, projector = train_distill_step_no_replay(
                student=student,
                projector=None,
                method="D4_KL_LATENT",
                current_npz=cfg["npz_path"],
                teacher_sac_model=teacher_model,
                epochs=50,
                lambda_feat=0.2,
                anchor_coeff=1e-6,
            )
        finally:
            # The teacher's env is not used for training; release it so one
            # simulator per task does not stay alive for the whole run.
            teacher_venv.close()
    else:
        student, _ = train_distill_step_no_replay(
            student=student,
            method=method,
            current_npz=cfg["npz_path"],
            epochs=50,
            anchor_coeff=1e-6,
        )

    # eval after each task
    venv_eval = load_eval_env_with_vecnorm(cfg["env_id"], cfg["vec_path"], seed=0)
    try:
        mean_ret, std_ret = eval_offline_student(student, venv_eval)
    finally:
        venv_eval.close()
    print(f"--> Eval after task {i+1}: mean return = {mean_ret:.2f} +/- {std_ret:.2f}")

