# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.results_plotter import load_results, ts2xy
from asv_glider_bearing_dist_env import AsvGliderBearingEnv

# --- 1. Define the Rollout Callback ---
class TrajectoryPlotCallback(BaseCallback):
    """Periodically roll out the current policy and save an ASV/glider trajectory plot.

    Every ``render_freq`` callback calls, one full deterministic episode is run
    on ``eval_env``. That is a separate environment, so the training env's state
    is not touched. The ASV and glider positions are read from
    ``eval_env.unwrapped`` at every step, including the position right after
    ``reset()``. The plot is written to ``<log_dir>/rollout_<n_calls>.png``.

    Args:
        eval_env: Gymnasium-style env (``reset()`` -> ``(obs, info)``,
            ``step()`` -> 5-tuple) used only for these rollouts. Its unwrapped
            env must expose ``asv_pos`` and ``glider_pos`` arrays with at least
            two components (x, y).
        render_freq: Number of callback calls between rollouts.
        log_dir: Directory for the PNG files; created if missing.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, eval_env, render_freq=10000, log_dir="./rollouts/", verbose=0):
        super().__init__(verbose)
        self.eval_env = eval_env
        self.render_freq = render_freq
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)

    def _on_step(self) -> bool:
        """Run and plot a rollout every ``render_freq`` calls; never stop training."""
        if self.n_calls % self.render_freq == 0:
            asv_history, glider_history = self._rollout()
            save_path = self._save_plot(asv_history, glider_history)
            print(f">>> Saved trajectory rollout to {save_path}")

        # Returning True tells SB3 to continue training.
        return True

    def _rollout(self):
        """Run one deterministic episode and return (asv_positions, glider_positions) arrays."""
        obs, _ = self.eval_env.reset()
        base_env = self.eval_env.unwrapped

        # Record the post-reset positions so that index 0 is the true starting
        # point used for the "Start" markers, not the position after the
        # first action. Positions come from the unwrapped env, not from obs.
        asv_history = [base_env.asv_pos.copy()]
        glider_history = [base_env.glider_pos.copy()]

        done = False
        while not done:
            action, _ = self.model.predict(obs, deterministic=True)
            obs, _, terminated, truncated, _ = self.eval_env.step(action)
            # .copy() because the env may update these arrays in place.
            asv_history.append(base_env.asv_pos.copy())
            glider_history.append(base_env.glider_pos.copy())
            done = terminated or truncated

        return np.array(asv_history), np.array(glider_history)

    def _save_plot(self, asv_history, glider_history):
        """Draw both trajectories with their start markers, save a PNG, and return its path."""
        fig, ax = plt.subplots(figsize=(6, 6))
        try:
            ax.plot(glider_history[:, 0], glider_history[:, 1], 'g--', label="Glider Path", alpha=0.6)
            ax.plot(asv_history[:, 0], asv_history[:, 1], 'b-', label="ASV Path", linewidth=2)
            ax.scatter(asv_history[0, 0], asv_history[0, 1], c='blue', label="ASV Start")
            ax.scatter(glider_history[0, 0], glider_history[0, 1], c='green', label="Glider Start")

            ax.set_title(f"Rollout at Step {self.n_calls}")
            ax.set_xlabel("X Position (m)")
            ax.set_ylabel("Y Position (m)")
            ax.legend()
            ax.grid(True)

            save_path = os.path.join(self.log_dir, f"rollout_{self.n_calls}.png")
            fig.savefig(save_path)
        finally:
            # Close this exact figure even if drawing/saving fails, so repeated
            # rollouts during training don't accumulate open figures.
            plt.close(fig)
        return save_path

# --- 2. Setup Directories and Environments ---
from stable_baselines3.common.callbacks import CallbackList, EvalCallback

log_dir = "./sac_asv_logs/"
os.makedirs(log_dir, exist_ok=True)

# Main training environment; Monitor records per-episode stats under log_dir.
env = AsvGliderBearingEnv()
env = Monitor(env, log_dir)

# Separate environment used only by the trajectory-plot callback, so its
# rollouts don't interfere with the training env's state. It has its own name
# (distinct from the evaluation env below) so it can be closed at the end.
plot_env = AsvGliderBearingEnv()

# --- 3. Initialize Model and Callback ---
model = SAC(
    "MlpPolicy",  # Policy type
    env,
    verbose=1,
    learning_rate=7e-5,
    gamma=0.99,
    buffer_size=1_000_000,
    learning_starts=10_000,
    batch_size=256,
    tau=0.005,
    train_freq=1,
    gradient_steps=1,
    device="cuda",
)

# Plot a deterministic rollout every 10k callback calls.
plot_callback = TrajectoryPlotCallback(plot_env, render_freq=10000)

# --- 4. Setup Evaluation & Real-time Printing ---
# This environment is used by EvalCallback to compute the mean reward periodically.
eval_log_dir = "./logs/eval/"
# SB3's Monitor only writes "<dir>/monitor.csv" when the directory already
# exists. Otherwise it writes "<path>.monitor.csv", which here would be a
# hidden "./logs/eval/.monitor.csv". So create the directory first.
os.makedirs(eval_log_dir, exist_ok=True)
eval_env = Monitor(AsvGliderBearingEnv(), eval_log_dir)  # Optional: logs eval results separately

# EvalCallback: prints the mean evaluation reward and keeps the best model.
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/best_model",
    log_path="./logs/results",
    eval_freq=5000,  # How often to calculate mean reward (every 5k steps)
    deterministic=True,
    render=False,
)

# Run evaluation and trajectory plotting side by side during training.
callbacks = CallbackList([eval_callback, plot_callback])

# --- 5. Train with Error Handling ---
try:
    print("Starting training. Look for 'Eval num_timesteps' in the output for mean rewards.")
    model.learn(
        total_timesteps=100_000,
        callback=callbacks,
        progress_bar=True,  # Adds a nice loading bar in the notebook
    )
except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
finally:
    # Save first: the trained model is the most valuable artifact and must not
    # be lost if closing one of the environments raises.
    model.save("sac_asv_bearing_dist_final")
    # Close every env created above. Closing the Monitor-wrapped ones also
    # closes their CSV log files.
    env.close()
    eval_env.close()
    plot_env.close()
    print("Environment closed and logs flushed. You can now run the plotting code.")