In [None]:
# Environment setup (Colab or local): fetch the project, install its requirements
# into the running kernel's interpreter, and make the `src` package importable.
# Guarded so that re-running the cell is safe (no re-clone, no duplicate sys.path entry).
import os
import subprocess
import sys

REPO_URL = "https://github.com/tam4x/rl_project_25.git"
# Resolve relative to the current working directory so clone location and sys.path agree
# everywhere (on Colab the cwd is /content, matching the previous hard-coded path).
REPO_DIR = os.path.abspath("rl_project_25")

if not os.path.isdir(REPO_DIR):
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

# `sys.executable -m pip` targets the kernel's environment (same as %pip); check=True fails fast.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-r", os.path.join(REPO_DIR, "requirements.txt")],
    check=True,
)

if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

##### Import Libraries

In [None]:
from typing import Callable, Dict, Any, Tuple, List
import os
import numpy as np
from stable_baselines3.common.evaluation import evaluate_policy
from src.teacher import *
from src.memory import load_sac_teacher, collect_memory_from_sac_teacher, save_memory_npz
from src.distillation import *
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
# Baseline (unshaped) tasks, kept for reference.
tasks = [
    Task("BASE HalfCheetah", lambda: task_base("HalfCheetah-v4", seed=0)),
    Task("BASE Walker2d",    lambda: task_base("Walker2d-v4", seed=0)),
]

# Single source of truth for the HalfCheetah curriculum: task name -> target velocity.
# Insertion order defines the task ids (0, 1, 2, ...) used for memory collection and distillation,
# so halfcheetah_tasks and task_list below are aligned by construction.
HC_TARGET_VELOCITIES = {
    "HC_WALK_v1.0":   1.0,
    "HC_RUN_v6.0":    6.0,
    "HC_BACK_v-1.0": -1.0,
}
HC_SEED = 0
TEACHER_DIR = "./teachers"


def make_hc_env_fn(target_velocity: float, seed: int = HC_SEED) -> Callable[[], Any]:
    """Return a zero-argument factory for the HalfCheetah target-velocity task.

    Binding ``target_velocity`` through a function argument avoids the late-binding
    pitfall of writing ``lambda: ...(v)`` inside a comprehension (all lambdas would
    otherwise see the last ``v``).
    """
    return lambda: task_halfcheetah_target_velocity(target_velocity, seed=seed)


halfcheetah_tasks = [
    Task(name, make_hc_env_fn(velocity)) for name, velocity in HC_TARGET_VELOCITIES.items()
]

walker_tasks = [
    Task("W_WALK_v1.0",  lambda: task_walker2d_target_velocity( 1.0, seed=1)),
    Task("W_RUN_v3.5",   lambda: task_walker2d_target_velocity( 3.5, seed=1)),
    Task("W_BACK_v-1.0", lambda: task_walker2d_target_velocity(-1.0, seed=1)),
    # Optional (add later):
    # Task("W_JUMP", lambda: task_walker2d_jump(seed=1, baseline_height=1.25, beta=5.0)),
]

# Teacher artifacts per task. NOTE(review): the paths assume train_teacher_for_task writes
# "<out_dir>/<task name>_SAC.zip" and "..._SAC_vecnormalize.pkl" into TEACHER_DIR — confirm.
task_list = [
    {
        "env_fn": make_hc_env_fn(velocity),
        "model_path": f"{TEACHER_DIR}/{name}_SAC.zip",
        "vec_path": f"{TEACHER_DIR}/{name}_SAC_vecnormalize.pkl",
        "task": velocity,
    }
    for name, velocity in HC_TARGET_VELOCITIES.items()
]


In [None]:
# Train one SAC teacher per HalfCheetah task. All runs share these settings; only the task varies.
teacher_train_cfg = dict(
    algo="SAC",
    total_timesteps=1_000_000,  # recommended for shaped tasks; 300k may be low
    seed=0,
    normalize_obs=True,
    out_dir="./teachers",
    log_dir="./tb_logs",
)

results = []
for hc_task in halfcheetah_tasks:
    results.append(train_teacher_for_task(task=hc_task, **teacher_train_cfg))

results


##### Evaluate the Teachers (optional rendering)

In [None]:

# Evaluate every saved teacher on its own task; `reward` collects the mean episodic return per task.
reward = []
for cfg in task_list:
    mean_r, std_r = eval_teacher(
        model_path=cfg["model_path"],
        vec_path=cfg["vec_path"],
        make_env_fn=cfg["env_fn"],
        n_eval_episodes=10,
        seed=0,
    )
    print(f"Task {cfg['task']}: mean reward = {mean_r} +/- {std_r}")
    reward += [mean_r]

    # Optional rendering:
    # test_teacher_render(model_path=cfg["model_path"], vec_path=cfg["vec_path"], task=cfg["task"])

##### Load Teachers and Collect Replay Memory

In [None]:
MEM_DIR = "./memory_sac"
os.makedirs(MEM_DIR, exist_ok=True)

# The one-hot task id appended to each stored observation must have the same width as the
# student's task-id input, which is later sized as len(TASK_SEQUENCE) == len(halfcheetah_tasks).
# Derive it from the task list instead of hard-coding it so the two can never disagree.
N_TASKS = len(halfcheetah_tasks)
TASK_ID_ENCODING = "onehot"  # recommended

# Teachers are paired with tasks by position; fail fast before any expensive collection.
assert len(task_list) == N_TASKS, "task_list and halfcheetah_tasks must have same length"

all_mem_paths = []
for task_id, (t, attr) in enumerate(zip(halfcheetah_tasks, task_list)):
    # Load teacher + env
    model, venv = load_sac_teacher(
        task=t,
        model_path=attr["model_path"],
        vec_path=attr["vec_path"],
        seed=0,
    )
    try:
        # Collect memory WITH task-id in obs
        mem = collect_memory_from_sac_teacher(
            model=model,
            venv=venv,
            task_name=t.name,
            task_id=task_id,              # ← CRITICAL
            n_tasks=N_TASKS,              # ← CRITICAL
            task_id_encoding=TASK_ID_ENCODING,
            n_steps=200_000,
            deterministic_action=False,   # better coverage
            store_actions=True,
            seed=123,
        )
        out_path = os.path.join(MEM_DIR, f"{t.name}_task{task_id}_SAC_memory.npz")
        save_memory_npz(mem, out_path)
    finally:
        # Release the env even if collection or saving fails midway.
        venv.close()

    all_mem_paths.append(out_path)

all_mem_paths


##### Training Run for Distillation

In [None]:
MEM_DIR = "./memory_sac"

# halfcheetah_tasks and task_list must be aligned (same order)
assert len(task_list) == len(halfcheetah_tasks), "task_list and halfcheetah_tasks must have same length"

# One config dict per task, in curriculum order, consumed by the distillation/evaluation helpers.
n_hc_tasks = len(halfcheetah_tasks)
TASK_SEQUENCE = [
    {
        "name": task.name,                  # e.g. "HC_WALK_v1.0"
        "env_id": None,                     # not needed for these custom env_fns
        "env_fn": cfg["env_fn"],
        "npz_path": f"{MEM_DIR}/{task.name}_task{task_id}_SAC_memory.npz",
        "model_path": cfg["model_path"],
        "vec_path": cfg["vec_path"],
        "task_value": cfg["task"],          # e.g. 1.0 / 6.0 / -1.0 (optional metadata)
        "task_id": task_id,                 # 0, 1, 2
        "n_tasks": n_hc_tasks,              # 3
    }
    for task_id, (task, cfg) in enumerate(zip(halfcheetah_tasks, task_list))
]
print(TASK_SEQUENCE)

In [None]:
N_TASKS = len(TASK_SEQUENCE)

# --- infer act_dim from memory (stable) ---
# The `with` block closes the .npz file handle once the shape has been read.
# allow_pickle=True is kept because these files are produced locally by this notebook;
# do not point this at untrusted files.
with np.load(TASK_SEQUENCE[0]["npz_path"], allow_pickle=True) as first_mem:
    act_dim = first_mem["mu"].shape[1]

# --- infer ORIGINAL obs_dim from the env (not from memory) ---
# Build the first task env the same way the training loop does, without normalization.
first_cfg = TASK_SEQUENCE[0]
task0 = Task(first_cfg["name"], first_cfg["env_fn"])

venv_tmp = build_vec_env(task0, seed=0, normalize_obs=False)
try:
    obs0 = venv_tmp.reset()
    # Batched reset output; presumably (n_envs, obs_dim) — column count is the raw obs width.
    orig_obs_dim = obs0.shape[1]
finally:
    # Always release the temporary env, even if reset() fails.
    venv_tmp.close()

# The student sees the raw observation concatenated with a one-hot task id.
student_obs_dim = orig_obs_dim + N_TASKS

print(f"Original env obs dim: {orig_obs_dim}")
print(f"Student obs dim (with one-hot task id): {student_obs_dim}")



In [None]:

# Distillation objectives understood by training_loop_with_replay; `projector` is only used by D4.
methods = ["D1_KL", "D2_MSE", "D3_WKL", "D4_KL_LATENT"]
method = "D4_KL_LATENT"  # same as methods[3]

# Fresh student over [raw obs | one-hot task id] -> action distribution.
student = GaussianStudentPolicy(student_obs_dim, act_dim)
projector = None

replay_cfg = dict(
    max_replay_per_task=60_000,
    replay_ratio=0.20,
    anchor_coeff=1e-4,
)
student, projector = training_loop_with_replay(
    student=student,
    projector=projector,
    method=method,
    TASK_SEQUENCE=TASK_SEQUENCE,
    **replay_cfg,
)


In [None]:
# Persist the distilled student (plus the projector, which is None unless D4 was used) together
# with its obs/act dims, method name and task metadata.
STUDENT_DIR = "./students"
# NOTE(review): harmless if save_student already creates the directory; guards against a missing dir.
os.makedirs(STUDENT_DIR, exist_ok=True)
student_path = f"{STUDENT_DIR}/student_after_all_tasks_{method}.pt"

save_student(
    student=student,
    projector=projector,
    obs_dim=student_obs_dim,
    act_dim=act_dim,
    method=method,
    out_path=student_path,
    extra={
        "n_tasks": N_TASKS,
        "task_names": [cfg["name"] for cfg in TASK_SEQUENCE],
    },
)



In [None]:
# Reload the checkpoint saved for the currently selected `method` so that the evaluated student
# always matches the one just trained, then score it on every task in the curriculum.
student_path = f"./students/student_after_all_tasks_{method}.pt"
student, _ = load_student(student_path, device="cpu")  # second return value is unused here
final_results = evaluate_student_on_all_tasks(student, TASK_SEQUENCE, seed=0, n_episodes=10)
final_results