In [None]:
!git clone https://github.com/tam4x/rl_project_25.git
!pip install -r rl_project_25/requirements.txt
import sys
sys.path.append("/content/rl_project_25")

##### Import Libraries

In [3]:
from typing import Callable, Dict, Any, Tuple, List
import os
import numpy as np
from stable_baselines3.common.evaluation import evaluate_policy
from src.teacher import *
from src.memory import *
from src.distillation import *
from src.student import *
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

##### Create the tasks

In [4]:
# One (name, target velocity) row per task. The Task objects and the per-task
# checkpoint dicts in `task_list` are both generated from these tables, so
# names, velocities and seeds cannot drift apart between the two lists.
HC_TASK_SPECS = [
    ("HC_WALK_v1.0", 1.0),
    ("HC_RUN_v6.0", 6.0),
    ("HC_BACK_v-1.0", -1.0),
]
HC_SEED = 0

W_TASK_SPECS = [
    ("W_WALK_v1.0", 1.0),
    ("W_RUN_v3.5", 3.5),
    ("W_BACK_v-1.0", -1.0),
]
W_SEED = 1


def _velocity_task_fn(make_task: Callable[..., Any], velocity: float, seed: int) -> Callable[[], Any]:
    """Return a zero-argument env factory that calls ``make_task(velocity, seed=seed)``.

    Building the lambda inside a function binds ``velocity``/``seed`` per call;
    a bare ``lambda`` written inside a comprehension would close over the loop
    variable, and every factory would end up using the last velocity.
    """
    return lambda: make_task(velocity, seed=seed)


halfcheetah_tasks = [
    Task(name, _velocity_task_fn(task_halfcheetah_target_velocity, velocity, HC_SEED))
    for name, velocity in HC_TASK_SPECS
]

walker_tasks = [
    Task(name, _velocity_task_fn(task_walker2d_target_velocity, velocity, W_SEED))
    for name, velocity in W_TASK_SPECS
]
# Optional (add later):
# walker_tasks.append(Task("W_JUMP", lambda: task_walker2d_jump(seed=1, baseline_height=1.25, beta=5.0)))

# Teacher checkpoint locations per HalfCheetah task (same order as
# halfcheetah_tasks). Assumes the training cell's out_dir="./teachers" and
# algo="SAC" produce files named <name>_SAC.zip / <name>_SAC_vecnormalize.pkl.
task_list = [
    {
        "env_fn": _velocity_task_fn(task_halfcheetah_target_velocity, velocity, HC_SEED),
        "model_path": f"./teachers/{name}_SAC.zip",
        "vec_path": f"./teachers/{name}_SAC_vecnormalize.pkl",
        "task": velocity,
    }
    for name, velocity in HC_TASK_SPECS
]


##### Train the teachers (one SAC teacher per task)

In [None]:
# Hyperparameters shared by every teacher run; only the task changes per run.
teacher_train_kwargs = dict(
    algo="SAC",
    total_timesteps=1_000_000,
    seed=0,
    normalize_obs=True,
    out_dir="./teachers",
    log_dir="./tb_logs",
)

# An explicit loop (not a comprehension) keeps the results of finished runs
# in `results` even if a later, long training run fails.
results = []
for teacher_task in halfcheetah_tasks:
    results.append(train_teacher_for_task(task=teacher_task, **teacher_train_kwargs))

results


##### Evaluate each teacher (rendering is optional)

In [None]:

# Score every saved teacher on its own task; `reward` keeps the mean returns
# in task_list order.
reward = []
for teacher_cfg in task_list:
    teacher_eval = eval_teacher(
        model_path=teacher_cfg["model_path"],
        vec_path=teacher_cfg["vec_path"],
        make_env_fn=teacher_cfg["env_fn"],
        n_eval_episodes=10,
        seed=0,
    )
    mean_r, std_r = teacher_eval
    print(f"Task {teacher_cfg['task']}: mean reward = {mean_r} +/- {std_r}")
    reward.append(mean_r)

    # Optional: render a rollout of this teacher.
    #test_teacher_render(model_path=teacher_cfg["model_path"], vec_path=teacher_cfg["vec_path"], task=teacher_cfg["task"])

##### Load Teacher for Memory Creation

In [49]:
MEM_DIR = "./memory_sac"
os.makedirs(MEM_DIR, exist_ok=True)

# Width of the task-id encoding appended to each observation. Derived from the
# task list (not hard-coded) so it always matches the student-sizing cell,
# which uses len(TASK_SEQUENCE) for the one-hot width.
N_TASKS = len(halfcheetah_tasks)
TASK_ID_ENCODING = "onehot"  # "onehot" or "integer"

# Tasks and their checkpoint attributes are paired by position.
assert len(task_list) == N_TASKS, "task_list and halfcheetah_tasks must have same length"

all_mem_paths = []

for task_id, (t, attr) in enumerate(zip(halfcheetah_tasks, task_list)):
    # Load teacher + env
    model, venv = load_sac_teacher(
        task=t,
        model_path=attr["model_path"],
        vec_path=attr["vec_path"],
        seed=0,
    )

    try:
        # Collect memory WITH task-id in obs
        mem = collect_memory_from_sac_teacher(
            model=model,
            venv=venv,
            task_name=t.name,
            task_id=task_id,
            n_tasks=N_TASKS,
            task_id_encoding=TASK_ID_ENCODING,
            n_steps=50_000,
            deterministic_action=False,
            store_actions=True,
            seed=123,
        )
    finally:
        # Release the env even if collection fails part-way.
        venv.close()

    out_path = os.path.join(MEM_DIR, f"{t.name}_task{task_id}_SAC_memory.npz")
    save_memory_npz(mem, out_path)
    all_mem_paths.append(out_path)

all_mem_paths


  logger.deprecation(


Saved memory: ./memory_sac\HC_WALK_v1.0_task0_SAC_memory.npz
Saved memory: ./memory_sac\HC_RUN_v6.0_task1_SAC_memory.npz
Saved memory: ./memory_sac\HC_BACK_v-1.0_task2_SAC_memory.npz


['./memory_sac\\HC_WALK_v1.0_task0_SAC_memory.npz',
 './memory_sac\\HC_RUN_v6.0_task1_SAC_memory.npz',
 './memory_sac\\HC_BACK_v-1.0_task2_SAC_memory.npz']

##### Training Run for Distillation

In [53]:
import os

MEM_DIR = "./memory_sac"

# halfcheetah_tasks and task_list must be aligned (same order)
assert len(task_list) == len(halfcheetah_tasks), "task_list and halfcheetah_tasks must have same length"

# One config dict per task, in training order. `task_id` is the position in
# the sequence and `n_tasks` the total count, matching how memory was collected.
TASK_SEQUENCE = [
    {
        "name": t.name,
        "env_id": None,
        "env_fn": attr["env_fn"],
        # Built with os.path.join, exactly like the memory-collection cell
        # that writes these files, so reader and writer always agree.
        "npz_path": os.path.join(MEM_DIR, f"{t.name}_task{i}_SAC_memory.npz"),
        "model_path": attr["model_path"],
        "vec_path": attr["vec_path"],
        "task_value": attr["task"],
        "task_id": i,
        "n_tasks": len(halfcheetah_tasks),
    }
    for i, (t, attr) in enumerate(zip(halfcheetah_tasks, task_list))
]
print(TASK_SEQUENCE)

[{'name': 'HC_WALK_v1.0', 'env_id': None, 'env_fn': <function <lambda> at 0x000001D68BEF29E0>, 'npz_path': './memory_sac/HC_WALK_v1.0_task0_SAC_memory.npz', 'model_path': './teachers/HC_WALK_v1.0_SAC.zip', 'vec_path': './teachers/HC_WALK_v1.0_SAC_vecnormalize.pkl', 'task_value': 1.0, 'task_id': 0, 'n_tasks': 3}, {'name': 'HC_RUN_v6.0', 'env_id': None, 'env_fn': <function <lambda> at 0x000001D68BEF2A70>, 'npz_path': './memory_sac/HC_RUN_v6.0_task1_SAC_memory.npz', 'model_path': './teachers/HC_RUN_v6.0_SAC.zip', 'vec_path': './teachers/HC_RUN_v6.0_SAC_vecnormalize.pkl', 'task_value': 6.0, 'task_id': 1, 'n_tasks': 3}, {'name': 'HC_BACK_v-1.0', 'env_id': None, 'env_fn': <function <lambda> at 0x000001D68BEF2B00>, 'npz_path': './memory_sac/HC_BACK_v-1.0_task2_SAC_memory.npz', 'model_path': './teachers/HC_BACK_v-1.0_SAC.zip', 'vec_path': './teachers/HC_BACK_v-1.0_SAC_vecnormalize.pkl', 'task_value': -1.0, 'task_id': 2, 'n_tasks': 3}]


In [54]:
N_TASKS = len(TASK_SEQUENCE)

# Action dimension from the first task's memory file. `mu` is presumably the
# teacher's per-step action mean, shaped (n_steps, act_dim) — TODO confirm.
# The `with` block closes the NpzFile; a bare np.load keeps the file handle
# open, which on Windows prevents later overwriting of the memory file.
# NOTE(review): allow_pickle=True is only safe because these files are
# generated locally by this notebook.
with np.load(TASK_SEQUENCE[0]["npz_path"], allow_pickle=True) as first_mem:
    act_dim = first_mem["mu"].shape[1]

first_cfg = TASK_SEQUENCE[0]
task0 = Task(first_cfg["name"], first_cfg["env_fn"])

# Probe the raw (normalize_obs=False) env once for its observation width;
# close it even if reset fails.
venv_tmp = build_vec_env(task0, seed=0, normalize_obs=False)
try:
    obs0 = venv_tmp.reset()
    orig_obs_dim = obs0.shape[1]  # axis 0 is presumably the env batch
finally:
    venv_tmp.close()

# The student sees the env observation plus a task id of width N_TASKS;
# this assumes memory was collected with one-hot task-id encoding.
student_obs_dim = orig_obs_dim + N_TASKS

print(f"Original env obs dim: {orig_obs_dim}")
print(f"Student obs dim (with one-hot task id): {student_obs_dim}")



Original env obs dim: 17
Student obs dim (with one-hot task id): 20


In [64]:

methods = ["D1_KL", "D2_MSE", "D3_WKL", "D4_KL_LATENT"]
# Select the distillation objective by name rather than by list index, so the
# choice survives any reordering of `methods`.
method = "D4_KL_LATENT"
assert method in methods, f"unknown distillation method: {method}"

# Seed the global torch/numpy RNGs so the student's initial weights (and any
# global-RNG sampling inside the training loop, if it uses them) are
# reproducible across runs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- create student ---
student = GaussianStudentPolicy(student_obs_dim, act_dim)
# Passed in as None; training_loop_with_replay returns the projector to keep.
projector = None
epochs = 40
student, projector = training_loop_with_replay(
    student=student,
    projector=projector,
    method=method,
    TASK_SEQUENCE=TASK_SEQUENCE,
    epochs=epochs,
    max_replay_per_task=15_000,
    replay_ratio=0.20,
    anchor_coeff=1e-4,
)



 Training task 1/3: HC_WALK_v1.0
Epoch 01 | D4_KL_LATENT loss: 3.8123 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
Epoch 10 | D4_KL_LATENT loss: 0.3793 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
Epoch 20 | D4_KL_LATENT loss: 0.2527 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
Epoch 30 | D4_KL_LATENT loss: 0.2040 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
Epoch 40 | D4_KL_LATENT loss: 0.1782 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
--> Eval after task 1: mean return = -162.24 +/- 394.90

 Training task 2/3: HC_RUN_v6.0
Epoch 01 | D4_KL_LATENT loss: 4.6870 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
Epoch 10 | D4_KL_LATENT loss: 0.8601 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
Epoch 20 | D4_KL_LATENT loss: 0.6145 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
Epoch 30 | D4_KL_LATENT loss: 0.5105 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
Epoch 40 | D4_KL_LATENT loss: 0.4484 | replay=0.20 | anchor=0.0001 (lambda_feat=0.2)
--> Eval after task 2: mean

##### Save Student and Evaluate

In [65]:
# Checkpoint name encodes the distillation method and the epoch count.
student_ckpt_path = f"./students/student_after_all_tasks_{method}_{epochs}.pt"
# Create the output directory so saving works on a fresh checkout.
os.makedirs(os.path.dirname(student_ckpt_path), exist_ok=True)

save_student(
    student=student,
    projector=projector,
    obs_dim=student_obs_dim,
    act_dim=act_dim,
    method=method,
    out_path=student_ckpt_path,
    extra={
        "n_tasks": N_TASKS,
        "task_names": [cfg["name"] for cfg in TASK_SEQUENCE],
    },
)



Saved student to ./students/student_after_all_tasks_D4_KL_LATENT_40.pt


In [66]:
# Reload the saved checkpoint on CPU and evaluate it on every task in the sequence.
checkpoint_path = f"./students/student_after_all_tasks_{method}_{epochs}.pt"
student, _ = load_student(checkpoint_path, device="cpu")
final_results = evaluate_student_on_all_tasks(
    student,
    TASK_SEQUENCE,
    seed=0,
    n_episodes=10,
)
print(final_results)

[FINAL EVAL] HC_WALK_v1.0: -34.74 +/- 0.69
[FINAL EVAL] HC_RUN_v6.0: -399.82 +/- 9.56
[FINAL EVAL] HC_BACK_v-1.0: -1055.76 +/- 513.06
[('HC_WALK_v1.0', -34.741188049316406, 0.6855669021606445), ('HC_RUN_v6.0', -399.82366943359375, 9.55794620513916), ('HC_BACK_v-1.0', -1055.757080078125, 513.06396484375)]
