In [2]:
import gymnasium_robotics
import gymnasium as gym
from stable_baselines3 import SAC, HerReplayBuffer
from stable_baselines3.common.callbacks import EvalCallback, CallbackList, BaseCallback
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
import os
import numpy as np
import matplotlib.pyplot as plt

# --- Callback: Stop training if success > 90% ---
class StopTrainingOnSuccessRate(BaseCallback):
    def __init__(self, success_threshold=0.90, eval_freq=5000, verbose=1):
        super().__init__(verbose)
        self.success_threshold = success_threshold
        self.eval_freq = eval_freq

    def _on_step(self) -> bool:
        if self.num_timesteps % self.eval_freq == 0:
            successes = []
            for _ in range(10):
                obs = self.training_env.reset()
                done = False
                while not done:
                    action, _ = self.model.predict(obs, deterministic=True)
                    obs, reward, done, infos = self.training_env.step(action)
                    done = done[0]
                    infos = infos[0]
                successes.append(infos['is_success'])
            success_rate = np.mean(successes)
            if self.verbose:
                print(f"✅ [Success Monitor] Success rate: {success_rate * 100:.2f}%")
            if success_rate >= self.success_threshold:
                print(f"🏆 [Early Stop] Success rate {success_rate * 100:.2f}% ≥ {self.success_threshold * 100:.2f}%")
                return False
        return True

# --- Callback: Save checkpoint if success > 80% ---
class CheckpointOnSuccessRate(BaseCallback):
    def __init__(self, save_path, success_threshold=0.80, eval_freq=5000, verbose=1):
        super().__init__(verbose)
        self.save_path = save_path
        self.success_threshold = success_threshold
        self.eval_freq = eval_freq
        self.saved = False

    def _on_step(self) -> bool:
        if self.num_timesteps % self.eval_freq == 0 and not self.saved:
            successes = []
            for _ in range(10):
                obs = self.training_env.reset()
                done = False
                while not done:
                    action, _ = self.model.predict(obs, deterministic=True)
                    obs, reward, done, infos = self.training_env.step(action)
                    done = done[0]
                    infos = infos[0]
                successes.append(infos['is_success'])
            success_rate = np.mean(successes)
            if self.verbose:
                print(f"📦 [Checkpoint Check] Success rate: {success_rate * 100:.2f}%")
            if success_rate >= self.success_threshold:
                self.saved = True
                self.model.save(os.path.join(self.save_path, "checkpoint_model"))
                self.training_env.save(os.path.join(self.save_path, "checkpoint_vecnormalize.pkl"))
                print(f"✅ [Checkpoint Saved] at {success_rate * 100:.2f}% success rate")
        return True

# --- Configuration ---
env_id = "FetchReach-v3"
save_dir = f"./her_sac_{env_id}_results/"
best_model_save_path = os.path.join(save_dir, "best_model")
final_model_save_path = os.path.join(save_dir, f"her_sac_{env_id}_final")

os.makedirs(save_dir, exist_ok=True)

# --- Env Setup ---
env = gym.make(env_id, reward_type='dense')
env = DummyVecEnv([lambda: env])
env = VecNormalize(env, norm_obs=True, norm_reward=False)

# --- SAC + HER ---
model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=8,
        goal_selection_strategy='future',
    ),
    verbose=1,
    buffer_size=int(1e6),
    learning_starts=1000,
    batch_size=512,
    learning_rate=1e-4,
    gamma=0.99,
    tau=0.005,
)

# --- Callbacks ---
eval_callback = EvalCallback(
    env,
    best_model_save_path=best_model_save_path,
    log_path=save_dir,
    eval_freq=5000,
    n_eval_episodes=10,
    deterministic=True,
    render=False,
    verbose=1,
)

stop_callback = StopTrainingOnSuccessRate(success_threshold=0.90, eval_freq=5000, verbose=1)
checkpoint_callback = CheckpointOnSuccessRate(save_path=save_dir, success_threshold=0.80, eval_freq=5000, verbose=1)

callback = CallbackList([eval_callback, checkpoint_callback, stop_callback])

# --- Training ---
print(f"🚀 Starting training SAC + HER on {env_id}...")
model.learn(
    total_timesteps=1_000_000,
    log_interval=50,
    callback=callback
)

# --- Save final model + VecNormalize ---
model.save(final_model_save_path)
env.save(os.path.join(save_dir, "vecnormalize.pkl"))
print("✅ Final model and VecNormalize saved.")

# --- Final 1000-episode eval ---
print("\n🎯 Running final 1000-episode evaluation...")
eval_env = gym.make(env_id, reward_type='dense')
eval_env = DummyVecEnv([lambda: eval_env])
eval_env = VecNormalize.load(os.path.join(save_dir, "vecnormalize.pkl"), eval_env)
eval_env.training = False
eval_env.norm_reward = False

model = SAC.load(os.path.join(best_model_save_path, "best_model.zip"), env=eval_env)

successes = []
for _ in range(1000):
    obs = eval_env.reset()
    done = False
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, infos = eval_env.step(action)
        done = done[0]
        infos = infos[0]
    successes.append(infos['is_success'])

final_success = np.mean(successes) * 100.0
print(f"\n✅ Final Evaluation Success Rate: {final_success:.2f}%")
if final_success >= 90.0:
    print("🏆 Problem Solved!")
else:
    print("⚡ Not Fully Solved")

# --- Plot reward + success ---
print("\n📈 Plotting training curves...")
log_file = os.path.join(save_dir, "evaluations.npz")
if not os.path.exists(log_file):
    raise FileNotFoundError("No EvalCallback log file found!")

data = np.load(log_file)
timesteps = data["timesteps"]
results = data["results"]
mean_rewards = results.mean(axis=1)
success_rates = (results >= -1e-3).mean(axis=1) * 100

plt.figure(figsize=(10, 4))
plt.plot(timesteps, mean_rewards, label="Mean Eval Reward", color="blue")
plt.xlabel("Timesteps")
plt.ylabel("Mean Reward")
plt.title("Reward vs Timesteps")
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "reward_plot.png"))
plt.close()

plt.figure(figsize=(10, 4))
plt.plot(timesteps, success_rates, label="Success Rate (%)", color="green")
plt.xlabel("Timesteps")
plt.ylabel("Success Rate (%)")
plt.title("Success Rate vs Timesteps")
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "success_plot.png"))
plt.close()

print("✅ Plots saved: reward_plot.png & success_plot.png")


Using cuda device
🚀 Starting training SAC + HER on FetchReach-v3...
---------------------------------
| rollout/           |          |
|    success_rate    | 0        |
| time/              |          |
|    episodes        | 50       |
|    fps             | 186      |
|    time_elapsed    | 13       |
|    total_timesteps | 2500     |
| train/             |          |
|    actor_loss      | -17.9    |
|    critic_loss     | 0.11     |
|    ent_coef        | 0.861    |
|    ent_coef_loss   | -0.996   |
|    learning_rate   | 0.0001   |
|    n_updates       | 1499     |
---------------------------------
Eval num_timesteps=5000, episode_reward=-17.72 +/- 2.62
Episode length: 50.00 +/- 0.00
Success rate: 0.00%
---------------------------------
| eval/              |          |
|    mean_ep_length  | 50       |
|    mean_reward     | -17.7    |
|    success_rate    | 0.0      |
| time/              |          |
|    total_timesteps | 5000     |
| train/             |          |
|    acto

FINE TUNING: 

In [6]:
import os
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import SAC
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# --- Paths ---
checkpoint_dir = "./her_sac_FetchReach-v3_results/"
model_path = os.path.join(checkpoint_dir, "checkpoint_model.zip")  # ✅ 70% checkpoint
vecnorm_path = os.path.join(checkpoint_dir, "checkpoint_vecnormalize.pkl")
save_path = os.path.join(checkpoint_dir, "finetuned_model_stage3.zip")

# --- Reload environment ---
env = gym.make("FetchReach-v3", reward_type="dense")
env = DummyVecEnv([lambda: env])
env = VecNormalize.load(vecnorm_path, env)
env.training = True
env.norm_reward = False

# --- Load model ---
model = SAC.load(model_path, env=env)

# --- Conservative hyperparameters ---
model.learning_rate = 2e-5
model.learning_starts = 5000
model.actor.optimizer.param_groups[0]['lr'] = 2e-5
model.critic.optimizer.param_groups[0]['lr'] = 2e-5
model.batch_size = 1024

# ✅ Reset HER buffer with reduced goal sampling
model.replay_buffer = HerReplayBuffer(
    buffer_size=model.buffer_size,
    observation_space=model.observation_space,
    action_space=model.action_space,
    env=env,
    device=model.device,
    n_envs=1,
    optimize_memory_usage=False,
    handle_timeout_termination=True,
    n_sampled_goal=4,
    goal_selection_strategy="future"
)

# ✅ Prefill with 1 episode
print("📦 Prefilling HER buffer with 1 episode...")
obs = env.reset()
done = [False]
step = 0
while not done[0] and step < 60:
    action, _ = model.predict(obs, deterministic=False)
    next_obs, reward, done, info = env.step(action)
    model.replay_buffer.add(obs, next_obs, action, reward, done, info)
    obs = next_obs
    step += 1

# --- Callbacks ---
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=-5.0, verbose=1)
eval_callback = EvalCallback(
    env,
    best_model_save_path=os.path.join(checkpoint_dir, "finetuned_best_model_stage3"),
    log_path=checkpoint_dir,
    eval_freq=5000,
    n_eval_episodes=10,
    deterministic=True,
    render=False,
    callback_on_new_best=stop_callback,
    verbose=1
)

# --- Start fine-tuning ---
print("🚀 Starting Stage 3 Conservative Fine-Tuning...")
model.learn(
    total_timesteps=100_000,
    log_interval=50,
    callback=eval_callback,
    reset_num_timesteps=False
)

# --- Save model and vecnormalize ---
model.save(save_path)
env.save(os.path.join(checkpoint_dir, "vecnormalize_stage3.pkl"))
print(f"✅ Stage 3 fine-tuning complete. Model saved to: {save_path}")

# --- Plotting with true success rates ---
print("\n📈 Plotting updated reward + success rate curves...")

eval_file = os.path.join(checkpoint_dir, "evaluations.npz")
success_file = os.path.join(checkpoint_dir, "evaluations", "ep_success.npy")

if not os.path.exists(eval_file):
    raise FileNotFoundError("❌ Missing evaluations.npz")

data = np.load(eval_file)
timesteps = data["timesteps"]
results = data["results"]
mean_rewards = results.mean(axis=1)

if os.path.exists(success_file):
    successes = np.load(success_file)
    success_rates = successes.mean(axis=1) * 100
else:
    print("⚠️ Falling back to reward threshold for success plot.")
    success_rates = (results >= -1e-3).mean(axis=1) * 100

# Plot reward
plt.figure(figsize=(10, 4))
plt.plot(timesteps, mean_rewards, label="Mean Eval Reward", color="blue")
plt.xlabel("Timesteps")
plt.ylabel("Mean Reward")
plt.title("Fine-Tuning Stage 3: Reward vs Timesteps")
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(checkpoint_dir, "finetune_stage3_reward_plot.png"))
plt.close()

# Plot success
plt.figure(figsize=(10, 4))
plt.plot(timesteps, success_rates, label="Success Rate", color="green")
plt.xlabel("Timesteps")
plt.ylabel("Success Rate (%)")
plt.title("Fine-Tuning Stage 3: Success Rate vs Timesteps")
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(checkpoint_dir, "finetune_stage3_success_plot.png"))
plt.close()

print("✅ Plots saved:")
print("   🔹 finetune_stage3_reward_plot.png")
print("   🔹 finetune_stage3_success_plot.png")


📦 Prefilling HER buffer with 1 episode...
🚀 Starting Stage 3 Conservative Fine-Tuning...
---------------------------------
| rollout/           |          |
|    success_rate    | 0.38     |
| time/              |          |
|    episodes        | 6200     |
|    fps             | 121      |
|    time_elapsed    | 0        |
|    total_timesteps | 310050   |
| train/             |          |
|    actor_loss      | 24.4     |
|    critic_loss     | 10.7     |
|    ent_coef        | 0.00108  |
|    ent_coef_loss   | 181      |
|    learning_rate   | 0.0001   |
|    n_updates       | 309048   |
---------------------------------
---------------------------------
| rollout/           |          |
|    success_rate    | 0.28     |
| time/              |          |
|    episodes        | 6250     |
|    fps             | 121      |
|    time_elapsed    | 20       |
|    total_timesteps | 312550   |
| train/             |          |
|    actor_loss      | 2.7      |
|    critic_loss     | 0.08

RESTART FROM SCRATCH TRAINING: 

In [7]:
import os
import gymnasium as gym
import gymnasium_robotics
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import SAC, HerReplayBuffer
from stable_baselines3.common.callbacks import EvalCallback, CallbackList, BaseCallback
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.logger import configure

# === Custom Callbacks ===
class StopTrainingOnSuccessRate(BaseCallback):
    def __init__(self, success_threshold=0.90, eval_freq=2000, verbose=1):
        super().__init__(verbose)
        self.success_threshold = success_threshold
        self.eval_freq = eval_freq

    def _on_step(self) -> bool:
        if self.num_timesteps % self.eval_freq == 0:
            successes = []
            for _ in range(10):
                obs = self.training_env.reset()
                done = False
                while not done:
                    action, _ = self.model.predict(obs, deterministic=True)
                    obs, reward, done, infos = self.training_env.step(action)
                    done = done[0]
                    infos = infos[0]
                successes.append(infos['is_success'])
            success_rate = np.mean(successes)
            self.logger.record("custom/success_rate_eval", success_rate)
            if self.verbose:
                print(f"✅ [Monitor] Success rate: {success_rate * 100:.2f}%")
            if success_rate >= self.success_threshold:
                print(f"🏆 [Early Stop] Target success rate {success_rate * 100:.2f}% reached.")
                return False
        return True

class CheckpointOnSuccessRate(BaseCallback):
    def __init__(self, save_path, success_threshold=0.80, eval_freq=2000, verbose=1):
        super().__init__(verbose)
        self.save_path = save_path
        self.success_threshold = success_threshold
        self.eval_freq = eval_freq
        self.stabilized = False

    def _on_step(self) -> bool:
        if self.num_timesteps % self.eval_freq == 0:
            successes = []
            for _ in range(10):
                obs = self.training_env.reset()
                done = False
                while not done:
                    action, _ = self.model.predict(obs, deterministic=True)
                    obs, reward, done, infos = self.training_env.step(action)
                    done = done[0]
                    infos = infos[0]
                successes.append(infos['is_success'])
            success_rate = np.mean(successes)
            self.logger.record("custom/success_rate_checkpoint", success_rate)
            if success_rate >= self.success_threshold and not self.stabilized:
                print(f"📦 [Checkpoint] Saving model at {success_rate * 100:.2f}% success rate.")
                self.model.save(os.path.join(self.save_path, "checkpoint_model"))
                self.training_env.save(os.path.join(self.save_path, "checkpoint_vecnormalize.pkl"))
                # Reduce entropy and learning rate to stabilize
                self.model.ent_coef = 0.001
                self.model.lr_schedule = lambda _: 5e-5
                self.stabilized = True
        return True

# === Config ===
env_id = "FetchReach-v3"
save_dir = f"./tensorboard_sac_her_{env_id}/"
best_model_path = os.path.join(save_dir, "best_model")
final_model_path = os.path.join(save_dir, "final_model")

os.makedirs(save_dir, exist_ok=True)

# === Env Setup ===
env = gym.make(env_id, reward_type='dense')
env = DummyVecEnv([lambda: env])
env = VecNormalize(env, norm_obs=True, norm_reward=False)

# === SAC + HER ===
model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=16,
        goal_selection_strategy='future',
    ),
    verbose=1,
    buffer_size=int(1e6),
    learning_starts=500,
    batch_size=256,
    learning_rate=1e-4,
    gamma=0.98,
    tau=0.005,
    tensorboard_log=os.path.join(save_dir, "tb_log")
)

# === Set up TensorBoard logger ===
new_logger = configure(os.path.join(save_dir, "tb_log"), ["stdout", "tensorboard"])
model.set_logger(new_logger)

# === Callbacks ===
eval_callback = EvalCallback(
    env,
    best_model_save_path=best_model_path,
    log_path=save_dir,
    eval_freq=2000,
    n_eval_episodes=10,
    deterministic=True,
    render=False,
    verbose=1
)

checkpoint_callback = CheckpointOnSuccessRate(
    save_path=save_dir,
    success_threshold=0.80,
    eval_freq=2000,
    verbose=1
)

stop_callback = StopTrainingOnSuccessRate(
    success_threshold=0.90,
    eval_freq=2000,
    verbose=1
)

callback = CallbackList([eval_callback, checkpoint_callback, stop_callback])

# === Train ===
print(f"🚀 Training SAC + HER on {env_id} with TensorBoard...")
model.learn(total_timesteps=1_000_000, log_interval=50, callback=callback)

# === Save ===
model.save(final_model_path)
env.save(os.path.join(save_dir, "vecnormalize.pkl"))
print("✅ Final model and VecNormalize saved.")

# === TensorBoard usage ===
print("\n📊 To view TensorBoard logs:")
print(f"tensorboard --logdir {os.path.join(save_dir, 'tb_log')}")



Using cuda device
Logging to ./tensorboard_sac_her_FetchReach-v3/tb_log
🚀 Training SAC + HER on FetchReach-v3 with TensorBoard...
Eval num_timesteps=2000, episode_reward=-22.50 +/- 5.09
Episode length: 50.00 +/- 0.00
Success rate: 0.00%
---------------------------------
| eval/              |          |
|    mean_ep_length  | 50       |
|    mean_reward     | -22.5    |
|    success_rate    | 0.0      |
| time/              |          |
|    total_timesteps | 2000     |
| train/             |          |
|    actor_loss      | -18.2    |
|    critic_loss     | 0.133    |
|    ent_coef        | 0.861    |
|    ent_coef_loss   | -0.995   |
|    learning_rate   | 0.0001   |
|    n_updates       | 1499     |
---------------------------------
New best mean reward!
✅ [Monitor] Success rate: 0.00%
-----------------------------------------
| custom/                    |          |
|    success_rate_checkpoint | 0.0      |
|    success_rate_eval       | 0.0      |
| rollout/                   | 

FINE TUNE NEW MODEL: 

In [14]:
import os
import gymnasium as gym
import gymnasium_robotics # noqa - Registering FetchReach environments
import numpy as np
# import matplotlib.pyplot as plt # Not used in this script, but can be useful for custom plotting
from stable_baselines3 import SAC, HerReplayBuffer
from stable_baselines3.common.callbacks import EvalCallback, CallbackList, BaseCallback
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.logger import configure
from stable_baselines3.common.monitor import Monitor # To wrap the env for EvalCallback and custom callbacks

# === Custom Callbacks (Re-used from your original script) ===
class StopTrainingOnSuccessRate(BaseCallback):
    """
    Stop training when a threshold success rate is reached in evaluation.
    Uses the model's training environment for evaluation.

    :param success_threshold: The target success rate (e.g., 0.9 for 90%).
    :param eval_freq: How often to evaluate the success rate (in timesteps).
    :param verbose: Verbosity level.
    :param n_eval_episodes: Number of episodes for evaluation.
    """
    def __init__(self, success_threshold=0.90, eval_freq=2000, verbose=1, n_eval_episodes=10):
        super().__init__(verbose)
        self.success_threshold = success_threshold
        self.eval_freq = eval_freq
        self.n_eval_episodes = n_eval_episodes

    def _on_step(self) -> bool:
        # Check if it's time to evaluate
        if self.num_timesteps % self.eval_freq == 0:
            successes = []
            # Get the vectorized training environment from the model
            eval_env = self.model.get_env()

            if eval_env is None:
                if self.verbose > 0:
                    print("⚠️ [Monitor - Stop CB] Evaluation environment not found in model.")
                return True # Continue training, cannot evaluate

            for _ in range(self.n_eval_episodes):
                # Reset the environment. For VecEnv, reset returns a list/dict of obs.
                # We expect a single environment in DummyVecEnv, so obs_list[0]
                obs_list = eval_env.reset()
                obs = obs_list[0] if isinstance(obs_list, list) else obs_list
                done_list = [False] # For VecEnv, done is a list

                while not done_list[0]:
                    action, _ = self.model.predict(obs, deterministic=True)
                    # Step in the vectorized environment
                    next_obs_list, _, done_list, infos_list = eval_env.step(action)
                    obs = next_obs_list[0] if isinstance(next_obs_list, list) else next_obs_list

                # Extract 'is_success' from the info dictionary
                # Monitor wrapper places it in 'final_info' if truncated, or directly if terminated
                info = infos_list[0]
                is_success = info.get('is_success')
                if is_success is None and 'final_info' in info and info['final_info'] is not None:
                    is_success = info['final_info'].get('is_success')
                
                successes.append(is_success if is_success is not None else 0.0)


            if successes: # Ensure successes list is not empty
                success_rate = np.mean(successes)
                self.logger.record("custom/success_rate_eval_stop", success_rate)
                if self.verbose > 0:
                    print(f"✅ [Monitor - Stop CB] Success rate: {success_rate * 100:.2f}% (Timesteps: {self.num_timesteps})")
                if success_rate >= self.success_threshold:
                    print(f"🏆 [Early Stop CB] Target success rate {self.success_threshold*100:.2f}% reached ({success_rate * 100:.2f}%). Stopping training.")
                    return False # Stop training
            else:
                if self.verbose > 0:
                    print(f"⚠️ [Monitor - Stop CB] No success info recorded in {self.n_eval_episodes} episodes. Check 'is_success' key and Monitor wrapper.")

        return True # Continue training

class CheckpointOnSuccessRate(BaseCallback):
    """
    Save a checkpoint of the model when a threshold success rate is reached.
    Uses the model's training environment for evaluation.

    :param save_path: Path to save the checkpoint.
    :param success_threshold: The success rate threshold to trigger checkpointing.
    :param eval_freq: How often to evaluate the success rate (in timesteps).
    :param verbose: Verbosity level.
    :param n_eval_episodes: Number of episodes to run for evaluation.
    """
    def __init__(self, save_path, success_threshold=0.80, eval_freq=2000, verbose=1, n_eval_episodes=10):
        super().__init__(verbose)
        self.save_path = save_path
        self.success_threshold = success_threshold
        self.eval_freq = eval_freq
        self.n_eval_episodes = n_eval_episodes
        self.last_saved_timestep_checkpoint = -1 # Avoid saving multiple times for the same threshold if success fluctuates

    def _on_step(self) -> bool:
        if self.num_timesteps % self.eval_freq == 0:
            successes = []
            eval_env = self.model.get_env()

            if eval_env is None:
                if self.verbose > 0:
                    print("⚠️ [Monitor - Checkpoint CB] Evaluation environment not found in model.")
                return True

            for _ in range(self.n_eval_episodes):
                obs_list = eval_env.reset()
                obs = obs_list[0] if isinstance(obs_list, list) else obs_list
                done_list = [False]
                while not done_list[0]:
                    action, _ = self.model.predict(obs, deterministic=True)
                    next_obs_list, _, done_list, infos_list = eval_env.step(action)
                    obs = next_obs_list[0] if isinstance(next_obs_list, list) else next_obs_list
                
                info = infos_list[0]
                is_success = info.get('is_success')
                if is_success is None and 'final_info' in info and info['final_info'] is not None:
                    is_success = info['final_info'].get('is_success')
                successes.append(is_success if is_success is not None else 0.0)

            if successes:
                success_rate = np.mean(successes)
                self.logger.record("custom/success_rate_checkpoint", success_rate)
                if self.verbose > 0:
                     print(f"🧐 [Monitor - Checkpoint CB] Success rate: {success_rate * 100:.2f}% (Timesteps: {self.num_timesteps})")

                # Check if success rate is above threshold AND we haven't saved for this specific success level recently
                # This simple check avoids re-saving if it hovers around the threshold.
                # A more robust 'stabilized' flag might involve multiple consecutive high scores.
                if success_rate >= self.success_threshold and self.num_timesteps > self.last_saved_timestep_checkpoint:
                    checkpoint_model_path = os.path.join(self.save_path, f"checkpoint_model_ft_succ{int(success_rate*100)}_{self.num_timesteps}")
                    checkpoint_vecnorm_path = os.path.join(self.save_path, f"checkpoint_vecnormalize_ft_succ{int(success_rate*100)}_{self.num_timesteps}.pkl")

                    print(f"📦 [Checkpoint CB] Saving model at {success_rate * 100:.2f}% success rate to {checkpoint_model_path}")
                    self.model.save(checkpoint_model_path)
                    self.last_saved_timestep_checkpoint = self.num_timesteps # Update last saved timestep

                    # Save VecNormalize stats if the training environment is VecNormalize
                    # self.training_env is an alias for self.model.get_env() in callbacks
                    if isinstance(self.training_env, VecNormalize):
                        self.training_env.save(checkpoint_vecnorm_path)
                        if self.verbose > 0:
                            print(f"📦 [Checkpoint CB] Saved VecNormalize stats to {checkpoint_vecnorm_path}")
            else:
                if self.verbose > 0:
                    print(f"⚠️ [Monitor - Checkpoint CB] No success info recorded in {self.n_eval_episodes} episodes for checkpointing.")
        return True


# === Config for Fine-Tuning ===
env_id = "FetchReach-v3"
original_save_dir = f"./tensorboard_sac_her_{env_id}/" # Path to original model and stats
# Choose which model to load. Ensure this file exists.
# model_to_load_name = "final_model.zip"
# model_to_load_name = "best_model.zip"
model_to_load_name = "checkpoint_model.zip" # From your previous training

# Path for saving the fine-tuned model and its logs
finetune_save_dir = f"./tensorboard_sac_her_{env_id}_finetuned/"
finetune_best_model_path = os.path.join(finetune_save_dir, "best_finetuned_model")
finetune_final_model_path = os.path.join(finetune_save_dir, "final_finetuned_model")

os.makedirs(finetune_save_dir, exist_ok=True)

# === Env Setup for Fine-Tuning ===
def make_env():
    """Helper function to create and wrap the environment."""
    env = gym.make(env_id, reward_type='dense') # Use the same reward_type as original training
    env = Monitor(env) # Add Monitor to track episode stats like 'is_success' for callbacks
    return env

# Create the base vectorized environment for training
# This will be wrapped by VecNormalize
base_training_vec_env = DummyVecEnv([make_env])

# Load the VecNormalize statistics for the training environment
vecnormalize_path = os.path.join(original_save_dir, "vecnormalize.pkl") # Or specific checkpoint vecnormalize

if os.path.exists(vecnormalize_path):
    print(f"🔄 Loading training VecNormalize statistics from {vecnormalize_path}")
    # Load the saved VecNormalize stats and apply them to the base_training_vec_env
    training_env = VecNormalize.load(vecnormalize_path, base_training_vec_env)
    # For fine-tuning, continue to update normalization stats
    training_env.training = True
    # Ensure norm_reward setting matches original if it was part of VecNormalize config (usually False for HER)
    # training_env.norm_reward = False # Explicitly set if needed, load should handle it.
    print("✅ Training VecNormalize loaded and configured.")
else:
    print(f"⚠️ Training VecNormalize file not found at {vecnormalize_path}. Using new VecNormalize wrapper.")
    print("This might lead to unexpected behavior if the original model relied on specific normalization.")
    training_env = VecNormalize(base_training_vec_env, norm_obs=True, norm_reward=False, gamma=0.98) # Match original gamma if needed


# === Load Pre-trained Model ===
load_model_path = os.path.join(original_save_dir, model_to_load_name)
if not os.path.exists(load_model_path):
    raise FileNotFoundError(f"Model to load not found: {load_model_path}. Please check the path and filename.")

print(f"🔄 Loading pre-trained model from {load_model_path}")

# Custom objects for loading. Crucially, HerReplayBuffer needs the env.
# The learning rate here will be the new initial learning rate for fine-tuning.
custom_objects = {
    "learning_rate": 5e-5,  # Example: A potentially smaller learning rate for fine-tuning
    "replay_buffer_class": HerReplayBuffer,
    "replay_buffer_kwargs": dict(
        n_sampled_goal=16, # Must match original if not re-specifying everything
        goal_selection_strategy='future', # Must match original
    ),
    "gamma": 0.98, # Ensure consistency if these were non-default
    "tau": 0.005,
    # "buffer_size": int(1e6), # Usually part of saved model, but can be specified
}

# Load the model, providing the training_env directly
model = SAC.load(
    load_model_path,
    env=training_env, # CRITICAL FIX: Pass the prepared training_env here
    custom_objects=custom_objects,
    # device='cuda' # Uncomment if you want to force a device, e.g., 'cuda' or 'cpu'
)

# model.set_env(training_env) # This is now redundant as env is passed to load and set internally.

print(f"✅ Model loaded. Policy learning rate: {model.learning_rate}, Entropy coef: {model.ent_coef if hasattr(model, 'ent_coef') else 'N/A (auto?)'}")


# === Set up TensorBoard logger for Fine-Tuning ===
finetune_tb_log_path = os.path.join(finetune_save_dir, "tb_log_ft")
new_logger_ft = configure(finetune_tb_log_path, ["stdout", "tensorboard"])
model.set_logger(new_logger_ft)
print(f"📝 Logging fine-tuning to: {finetune_tb_log_path}")

# === Callbacks for Fine-Tuning ===
# Setup evaluation environment for EvalCallback
base_eval_env_ft = DummyVecEnv([make_env]) # Create a separate base instance for evaluation

if os.path.exists(vecnormalize_path):
    print(f"🔄 Loading VecNormalize statistics for separate eval_env_ft from {vecnormalize_path}")
    eval_env_ft = VecNormalize.load(vecnormalize_path, base_eval_env_ft)
    eval_env_ft.training = False  # CRUCIAL: Do not update normalization stats during evaluation
    eval_env_ft.norm_reward = False # Match training env's reward normalization setting
    print("✅ EvalCallback's VecNormalize loaded and configured.")
else:
    print(f"⚠️ VecNormalize file not found for EvalCallback's env at {vecnormalize_path}.")
    print("     EvalCallback will use a new VecNormalize wrapper with fresh (non-synced) stats.")
    eval_env_ft = VecNormalize(base_eval_env_ft, norm_obs=True, norm_reward=False, training=False, gamma=0.98)


eval_callback_ft = EvalCallback(
    eval_env_ft, 
    best_model_save_path=finetune_best_model_path,
    log_path=finetune_save_dir,
    eval_freq=5000, # Adjust frequency as needed for fine-tuning
    n_eval_episodes=20, # More episodes for robust evaluation
    deterministic=True,
    render=False,
    verbose=1
)

checkpoint_callback_ft = CheckpointOnSuccessRate(
    save_path=finetune_save_dir,
    success_threshold=0.90, # Potentially a higher threshold if fine-tuning
    eval_freq=5000,
    verbose=1,
    n_eval_episodes=10 # Custom callbacks use the training_env
)

stop_callback_ft = StopTrainingOnSuccessRate(
    success_threshold=0.98, # Higher threshold for stopping fine-tuning
    eval_freq=5000,
    verbose=1,
    n_eval_episodes=10 # Custom callbacks use the training_env
)

callback_ft = CallbackList([eval_callback_ft, checkpoint_callback_ft, stop_callback_ft])

# === Fine-Tune the Model ===
fine_tune_total_timesteps = 250_000 # Example: fine-tune for additional timesteps
log_interval_ft = 25

print(f"🚀 Fine-tuning SAC + HER on {env_id} for {fine_tune_total_timesteps} timesteps...")
# Consider reset_num_timesteps based on how you want to view logs.
# True: Fine-tuning starts from step 0 in new TensorBoard log.
# False: Continues from loaded model's timesteps.
model.learn(
    total_timesteps=fine_tune_total_timesteps,
    log_interval=log_interval_ft,
    callback=callback_ft,
    reset_num_timesteps=True # Recommended for clear fine-tuning logs
)

# === Save Fine-Tuned Model and VecNormalize Stats ===
model.save(finetune_final_model_path)
# Save the updated VecNormalize stats from the fine-tuning *training* environment
training_env.save(os.path.join(finetune_save_dir, "vecnormalize_finetuned.pkl"))
print(f"✅ Fine-tuned model saved to {finetune_final_model_path}")
print(f"✅ Updated VecNormalize stats saved to {os.path.join(finetune_save_dir, 'vecnormalize_finetuned.pkl')}")

# === TensorBoard usage for fine-tuning logs ===
print("\n📊 To view fine-tuning TensorBoard logs:")
print(f"tensorboard --logdir {finetune_tb_log_path}")
print("\n📊 To view original and fine-tuning logs together (if paths are structured correctly):")
# Assumes original logs are in original_save_dir/tb_log
original_tb_log_path = os.path.join(original_save_dir, "tb_log")
print(f"tensorboard --logdir_spec original:{original_tb_log_path},finetuned:{finetune_tb_log_path}")

# Close environments
training_env.close()
eval_env_ft.close()
print("Environments closed.")



🔄 Loading training VecNormalize statistics from ./tensorboard_sac_her_FetchReach-v3/vecnormalize.pkl
✅ Training VecNormalize loaded and configured.
🔄 Loading pre-trained model from ./tensorboard_sac_her_FetchReach-v3/checkpoint_model.zip
✅ Model loaded. Policy learning rate: 5e-05, Entropy coef: auto
Logging to ./tensorboard_sac_her_FetchReach-v3_finetuned/tb_log_ft
📝 Logging fine-tuning to: ./tensorboard_sac_her_FetchReach-v3_finetuned/tb_log_ft
🔄 Loading VecNormalize statistics for separate eval_env_ft from ./tensorboard_sac_her_FetchReach-v3/vecnormalize.pkl
✅ EvalCallback's VecNormalize loaded and configured.
🚀 Fine-tuning SAC + HER on FetchReach-v3 for 250000 timesteps...
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 50       |
|    ep_rew_mean     | -16      |
|    success_rate    | 0.16     |
| time/              |          |
|    episodes        | 25       |
|    fps             | 185      |
|    time_elapsed    | 6        |
|    tot

In [2]:
import os
import gymnasium as gym
import gymnasium_robotics
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
import random
from torch.utils.tensorboard import SummaryWriter # Import TensorBoard

config = {
    "env_name": "FetchReach-v3",
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "steps": 20000, # Total environment steps for training
    "hidden_dim": 256,
    "log_every": 1000, # Log metrics every N steps
    "imag_horizon": 15, # Imagination horizon length (H)
    "batch_size": 64,
    "buffer_size": 100000,
    "prefill_steps": 1000, # Steps to prefill buffer with random actions
    "gamma": 0.99, # Discount factor for rewards
    "learning_rate": 3e-4,
    "action_scale": 0.05, # Scale factor for actor's output actions
    "dyn_loss_scale": 0.5, # Weight for the dynamics prediction loss in World Model
    "actor_critic_update_freq": 1, # How often to update actor/critic per env step (after prefill)
    "world_model_update_freq": 1, # How often to update world model per env step (after prefill)
}

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, obs, action, reward, next_obs, done):
        # Ensure obs and next_obs are base numpy arrays if they are dicts
        obs_flat = obs['observation'] if isinstance(obs, dict) else obs
        next_obs_flat = next_obs['observation'] if isinstance(next_obs, dict) else next_obs
        self.buffer.append((obs_flat, action, reward, next_obs_flat, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        obs, act, rew, next_obs, done = map(np.stack, zip(*batch))
        return (
            torch.tensor(obs, dtype=torch.float32).to(config["device"]),
            torch.tensor(act, dtype=torch.float32).to(config["device"]),
            torch.tensor(rew, dtype=torch.float32).unsqueeze(1).to(config["device"]),
            torch.tensor(next_obs, dtype=torch.float32).to(config["device"]),
            torch.tensor(done, dtype=torch.float32).unsqueeze(1).to(config["device"])
        )

    def __len__(self):
        return len(self.buffer)

class WorldModel(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        # Encodes observation s_t to an embedding e_t
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, config["hidden_dim"]),
            nn.ReLU(),
            # Optionally add more layers or different activations
        )
        # GRU cell for dynamics: h_t = GRU(cat(e_t, a_t), h_{t-1})
        # Input to GRU is concatenation of observation embedding and action
        self.dynamics = nn.GRUCell(config["hidden_dim"] + act_dim, config["hidden_dim"])
        # Predicts reward from the GRU's hidden state h_t
        self.reward_head = nn.Linear(config["hidden_dim"], 1)
        # Head to predict next observation embedding (optional, but part of many world models)
        # For this script, dynamics output (hidden state) IS the predicted next latent state.

    def forward(self, obs_embed, action, prev_gru_state):
        """
        Applies the dynamics model for one step.
        Args:
            obs_embed (Tensor): Embedding of the current observation (e_t).
            action (Tensor): Action taken (a_t).
            prev_gru_state (Tensor): Previous hidden state of the GRU (h_{t-1}).
        Returns:
            next_gru_state (Tensor): Next hidden state of the GRU (h_t), which is the predicted latent.
            pred_reward (Tensor): Predicted reward (r_t).
        """
        # Concatenate current observation embedding and action for GRU input
        gru_input = torch.cat([obs_embed, action], dim=-1)
        next_gru_state = self.dynamics(gru_input, prev_gru_state)
        pred_reward = self.reward_head(next_gru_state)
        return next_gru_state, pred_reward

    def encode_obs(self, obs):
        """Encodes a raw observation."""
        return self.encoder(obs)

class Actor(nn.Module):
    def __init__(self, feat_dim, act_dim): # feat_dim is hidden_dim (GRU state dim)
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim, config["hidden_dim"]),
            nn.ReLU(),
            nn.Linear(config["hidden_dim"], config["hidden_dim"]), # Added another layer
            nn.ReLU(),
            nn.Linear(config["hidden_dim"], act_dim),
            nn.Tanh() # Outputs actions in [-1, 1]
        )
        self.action_scale = torch.tensor(config["action_scale"], dtype=torch.float32)

    def forward(self, latent_state_features):
        # Scale Tanh output to desired action range
        return self.action_scale.to(latent_state_features.device) * self.net(latent_state_features)

class Critic(nn.Module):
    def __init__(self, feat_dim): # feat_dim is hidden_dim (GRU state dim)
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim, config["hidden_dim"]),
            nn.ReLU(),
            nn.Linear(config["hidden_dim"], config["hidden_dim"]), # Added another layer
            nn.ReLU(),
            nn.Linear(config["hidden_dim"], 1) # Outputs a scalar value
        )

    def forward(self, latent_state_features):
        return self.net(latent_state_features)

def imagine_rollout(actor, world_model, start_latent_state, horizon):
    """
    Generates an imagined trajectory using the actor and world model.
    Args:
        actor (Actor): The policy network.
        world_model (WorldModel): The learned world model.
        start_latent_state (Tensor): The initial latent state for imagination (batch_size, hidden_dim).
        horizon (int): The length of the imagined trajectory.
    Returns:
        List of tuples: [(imag_latent_k, imag_reward_k), ...], where k is from 1 to H.
                        imag_latent_k is the k-th imagined latent state (output of dynamics).
                        imag_reward_k is the k-th predicted reward.
    """
    imag_latents_rewards = []
    current_latent = start_latent_state # z_0

    for _ in range(horizon):
        action = actor(current_latent) # a_k ~ actor(z_k)
        # In our WorldModel, forward_through_dynamics is essentially the dynamics step.
        # It needs an "observation embedding" and "action" for its GRU input part.
        # During imagination, the "observation embedding" is the current_latent itself.
        # The GRU's "previous hidden state" is also the current_latent.
        # So, z_{k+1}, r_{k+1} = WM_dynamics(z_k, a_k, z_k)
        next_latent, pred_reward = world_model.forward(current_latent, action, current_latent)
        imag_latents_rewards.append((next_latent, pred_reward))
        current_latent = next_latent # Update for next imagination step
    return imag_latents_rewards


def make_env():
    env = gym.make(config["env_name"], reward_type="dense") # Ensure dense reward explicitly
    # Seed for reproducibility, though other sources of randomness exist (torch, numpy)
    # For full reproducibility, seed torch, numpy, and random.
    # obs, info = env.reset(seed=42) # Seeding here is good
    return env

def train():
    env = make_env()
    # Seeding for external libraries
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    obs_dict_init, _ = env.reset(seed=42)


    obs_space_dims = env.observation_space["observation"].shape[0]
    act_dim = env.action_space.shape[0]

    buffer = ReplayBuffer(config["buffer_size"])
    writer = SummaryWriter(log_dir=f"runs/{config['env_name']}_DreamerV3_style_{random.randint(0,10000)}") # Tensorboard writer

    # Initialize models
    wm = WorldModel(obs_space_dims, act_dim).to(config["device"])
    actor = Actor(config["hidden_dim"], act_dim).to(config["device"])
    critic = Critic(config["hidden_dim"]).to(config["device"])

    # Optimizers
    optim_wm = torch.optim.Adam(wm.parameters(), lr=config["learning_rate"], eps=1e-5) # Added eps for stability
    optim_actor = torch.optim.Adam(actor.parameters(), lr=config["learning_rate"], eps=1e-5)
    optim_critic = torch.optim.Adam(critic.parameters(), lr=config["learning_rate"], eps=1e-5)

    # Initial observation and agent state
    obs_dict, _ = env.reset()
    current_obs = obs_dict["observation"]
    
    # Prefill replay buffer with random actions
    print(f"Prefilling buffer with {config['prefill_steps']} random steps...")
    for _ in range(config["prefill_steps"]):
        action = env.action_space.sample()
        next_obs_dict, env_reward, done, truncated, _ = env.step(action)
        
        # Dense reward shaping for FetchReach
        reward = -np.linalg.norm(next_obs_dict["achieved_goal"] - next_obs_dict["desired_goal"])
        
        buffer.push(current_obs, action, reward, next_obs_dict["observation"], done or truncated)
        current_obs = next_obs_dict["observation"]
        if done or truncated:
            obs_dict, _ = env.reset()
            current_obs = obs_dict["observation"]
    print("Buffer prefill complete.")

    # Reset environment and agent state for training phase
    obs_dict, _ = env.reset()
    current_obs = obs_dict["observation"]
    agent_gru_state = torch.zeros(1, config["hidden_dim"]).to(config["device"]) # Agent's recurrent state h_{t-1}
    prev_action_tensor = torch.zeros(1, act_dim).to(config["device"])         # Agent's previous action a_{t-1}
    
    total_episodes_done = 0
    current_episode_reward = 0
    total_env_steps = 0

    for step in range(config["steps"]):
        total_env_steps +=1
        obs_tensor = torch.tensor(current_obs, dtype=torch.float32).unsqueeze(0).to(config["device"])

        with torch.no_grad():
            current_obs_embedding = wm.encode_obs(obs_tensor) # e_t = Enc(s_t)
            # Update agent's GRU state: h_t = GRU(cat(e_t, a_{t-1}), h_{t-1})
            # WorldModel.forward takes (obs_embed, action, prev_gru_state)
            next_agent_gru_state, _ = wm.forward(current_obs_embedding, prev_action_tensor, agent_gru_state)
            # Actor selects action based on current agent state h_t
            action_np = actor(next_agent_gru_state).squeeze(0).cpu().numpy()
        
        agent_gru_state = next_agent_gru_state # Update agent's state for the next step h_t becomes h_{t-1}

        # Interact with environment
        next_obs_dict, env_reward, done, truncated, _ = env.step(action_np)
        next_obs_env = next_obs_dict["observation"]
        
        # Dense reward shaping for FetchReach
        shaped_reward = -np.linalg.norm(next_obs_dict["achieved_goal"] - next_obs_dict["desired_goal"])
        
        buffer.push(current_obs, action_np, shaped_reward, next_obs_env, done or truncated)
        current_episode_reward += shaped_reward # Use shaped reward for learning

        # Prepare for next iteration
        current_obs = next_obs_env
        prev_action_tensor = torch.tensor(action_np, dtype=torch.float32).unsqueeze(0).to(config["device"])

        if done or truncated:
            writer.add_scalar('Reward/EpisodeActual', current_episode_reward, total_episodes_done)
            print(f"Step {step+1}, Episode {total_episodes_done+1}: Total Reward = {current_episode_reward:.2f}")
            current_episode_reward = 0
            total_episodes_done += 1
            
            obs_dict, _ = env.reset()
            current_obs = obs_dict["observation"]
            agent_gru_state = torch.zeros(1, config["hidden_dim"]).to(config["device"]) # Reset GRU state
            prev_action_tensor = torch.zeros(1, act_dim).to(config["device"])    # Reset prev_action

        if len(buffer) < config["batch_size"]: # Wait for enough samples for a batch
            continue

        # --- Training Phase ---
        # Sample a batch from replay buffer
        obs_b, act_b, rew_b, next_obs_b, done_b = buffer.sample(config["batch_size"])

        # 1. Train World Model
        if total_env_steps % config["world_model_update_freq"] == 0:
            optim_wm.zero_grad()
            
            # Encode current observations from batch: e_t = Enc(s_t)
            obs_embed_b = wm.encode_obs(obs_b)
            # For GRU, h_{t-1} for the batch. Using zeros as initial state for one-step predictions.
            initial_gru_state_batch = torch.zeros(config["batch_size"], config["hidden_dim"], device=config["device"])
            
            # Predict next latent state and reward: h_t_pred, r_t_pred = WM(e_t, a_t, h_{t-1}=0)
            h_t_pred_from_dynamics, pred_reward_batch = wm.forward(obs_embed_b, act_b, initial_gru_state_batch)

            # Reward prediction loss
            loss_wm_rew = F.mse_loss(pred_reward_batch, rew_b)

            # Dynamics prediction loss: predict embedding of next observation
            with torch.no_grad():
                target_next_latent_embed = wm.encode_obs(next_obs_b).detach() # e_{t+1} = Enc(s_{t+1})
            loss_wm_dyn = F.mse_loss(h_t_pred_from_dynamics, target_next_latent_embed)
            
            loss_wm = loss_wm_rew + config["dyn_loss_scale"] * loss_wm_dyn
            loss_wm.backward()
            #torch.nn.utils.clip_grad_norm_(wm.parameters(), 100.0) # Optional gradient clipping
            
            torch.nn.utils.clip_grad_norm_(wm.parameters(), 100.0)
            # ...
            torch.nn.utils.clip_grad_norm_(critic.parameters(), 100.0)
            # ...
            torch.nn.utils.clip_grad_norm_(actor.parameters(), 100.0)
            
            # Update world model parameters            
            optim_wm.step()
            
            writer.add_scalar('Loss/WM_Reward', loss_wm_rew.item(), step)
            writer.add_scalar('Loss/WM_Dynamics', loss_wm_dyn.item(), step)
            writer.add_scalar('Loss/WM_Total', loss_wm.item(), step)
            writer.add_scalar('Values/PredReward_Mean_WM', pred_reward_batch.mean().item(), step)

        # Latent state to start imagination from (this is h_t_pred from the world model)
        # Detach because we don't want gradients flowing back to WM from actor/critic losses via this start state.
        start_imag_latent = h_t_pred_from_dynamics.detach()


        # 2. Train Actor and Critic using imagined rollouts
        if total_env_steps % config["actor_critic_update_freq"] == 0:
            # Generate imagined trajectories
            # rollout contains [(h_1, r_1), ..., (h_H, r_H)] where H is imag_horizon
            rollout = imagine_rollout(actor, wm, start_imag_latent, config["imag_horizon"])
            
            # Unpack imagined latents and rewards
            # imag_latents: (H, batch_size, hidden_dim) -> h_1, ..., h_H
            # imag_rewards: (H, batch_size, 1)       -> r_1, ..., r_H
            imag_latents, imag_rewards = zip(*rollout)
            imag_latents = torch.stack(imag_latents) 
            imag_rewards = torch.stack(imag_rewards)

            # Calculate value targets (lambda-returns or GAE-style) for critic
            # Here using N-step bootstrapped returns: R_t = r_t + gamma * R_{t+1}, with R_H = V_critic(h_H)
            value_targets = torch.zeros_like(imag_rewards)
            with torch.no_grad():
                # Bootstrap value for the very last state of the horizon V(h_H)
                last_imag_state_value = critic(imag_latents[-1]) # V(h_H)

            # Calculate targets iteratively from H-1 down to 0
            # V_target(h_k) = r_k + gamma * V_target(h_{k+1})
            # V_target(h_H) = r_H + gamma * V_critic(h_H) (adjusting this for typical Bellman)
            # Let value_targets[k] be the target for critic(imag_latents[k])
            
            next_value_target = last_imag_state_value
            for t in reversed(range(config["imag_horizon"])):
                if t == config["imag_horizon"] - 1: # For V(h_H)
                    value_targets[t] = imag_rewards[t] + config["gamma"] * (1.0 - 0.0) * next_value_target # No (1-done) in imagination
                else: # For V(h_k), k < H
                    value_targets[t] = imag_rewards[t] + config["gamma"] * (1.0 - 0.0) * value_targets[t+1]
            
            # Critic Update
            # Predicted values V(h_k) for all imagined states by current critic
            # Note: imag_latents are already detached from WM graph for actor/critic training
            # if start_imag_latent was detached. However, gradients need to flow through actor for its update.
            # For critic loss, detach imag_latents to ensure targets are fixed.
            value_preds_critic = critic(imag_latents.detach()) # V_current(h_k)
            loss_critic = F.mse_loss(value_preds_critic, value_targets.detach())
            
            optim_critic.zero_grad()
            loss_critic.backward()
            #torch.nn.utils.clip_grad_norm_(critic.parameters(), 100.0) # Optional
            optim_critic.step()
            writer.add_scalar('Loss/Critic', loss_critic.item(), step)

            # Actor Update
            # Re-evaluate value predictions with current critic, allowing gradients to flow back to actor
            # The actor's actions influenced imag_latents through the (fixed for this update) world model.
            # We want to update actor to produce actions leading to high-value latents.
            value_preds_actor = critic(imag_latents) # Gradients will flow through imag_latents to actor
            loss_actor = -value_preds_actor.mean() # Maximize the mean predicted value of imagined states

            optim_actor.zero_grad()
            loss_actor.backward()
            #torch.nn.utils.clip_grad_norm_(actor.parameters(), 100.0) # Optional
            optim_actor.step()
            writer.add_scalar('Loss/Actor', loss_actor.item(), step)
            writer.add_scalar('Values/ImaginedReturn_Mean', value_targets.mean().item(), step)


        if (step + 1) % config["log_every"] == 0:
            print(f"--- Training Step {step+1}/{config['steps']} ---")
            if 'loss_wm' in locals(): # Check if WM was updated in this step block
                 print(f"  WM Losses: Reward={loss_wm_rew.item():.4f}, Dynamics={loss_wm_dyn.item():.4f}, Total={loss_wm.item():.4f}")
            if 'loss_critic' in locals():
                 print(f"  Actor-Critic Losses: Critic={loss_critic.item():.4f}, Actor={loss_actor.item():.4f}")
            print(f"  Buffer Size: {len(buffer)}")
            # Log average action magnitudes if desired
            writer.add_scalar('Params/ActionNorm_Mean', torch.norm(prev_action_tensor, p=2).item(), step)


    env.close()
    writer.close()
    print("Training complete.")

if __name__ == "__main__":
    train()

Prefilling buffer with 1000 random steps...
Buffer prefill complete.
Step 50, Episode 1: Total Reward = -28.20
Step 100, Episode 2: Total Reward = -23.74
Step 150, Episode 3: Total Reward = -29.36
Step 200, Episode 4: Total Reward = -22.25
Step 250, Episode 5: Total Reward = -29.98
Step 300, Episode 6: Total Reward = -32.52
Step 350, Episode 7: Total Reward = -18.45
Step 400, Episode 8: Total Reward = -22.38
Step 450, Episode 9: Total Reward = -27.04
Step 500, Episode 10: Total Reward = -21.68
Step 550, Episode 11: Total Reward = -22.06
Step 600, Episode 12: Total Reward = -33.06
Step 650, Episode 13: Total Reward = -30.31
Step 700, Episode 14: Total Reward = -27.18
Step 750, Episode 15: Total Reward = -29.82
Step 800, Episode 16: Total Reward = -21.21
Step 850, Episode 17: Total Reward = -28.49
Step 900, Episode 18: Total Reward = -25.43
Step 950, Episode 19: Total Reward = -29.98
Step 1000, Episode 20: Total Reward = -19.15
--- Training Step 1000/20000 ---
  WM Losses: Reward=0.0064,