In [1]:
import gymnasium as gym
from stable_baselines3 import TD3
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
import os

In [None]:
# Import custom environment
from drone_env import DroneEnv  # update path as needed

def make_env():
    """Build one fresh ``DroneEnv`` instance wrapped in a ``Monitor``.

    Returns:
        The ``Monitor``-wrapped environment; the wrapper is there for logging.
    """
    drone_env = DroneEnv()
    monitored_env = Monitor(drone_env)
    return monitored_env

def main():
    """Train a TD3 agent on ``DroneEnv``, save it, and print a final evaluation.

    Side effects:
        * Creates ``./logs`` if it does not exist; TensorBoard logs and the
          ``EvalCallback`` results are written there.
        * ``EvalCallback`` evaluates every 1500 callback steps and stores the
          best model under ``./logs/best_model``.
        * Saves the model obtained after 600,000 timesteps as
          ``./logs/final_model``.
        * Prints a completion message and the mean ± std reward measured over
          10 evaluation episodes.

    Both environments are closed before returning, even if training fails.
    """
    log_dir = "./logs"
    os.makedirs(log_dir, exist_ok=True)

    # Training environment. make_vec_env applies its own Monitor wrapper to
    # every environment it creates, so hand it the bare environment class;
    # passing make_env here would stack a second Monitor on top of the first.
    env = make_vec_env(DroneEnv, n_envs=1)

    # Evaluation environment: a single Monitor-wrapped instance kept separate
    # from the training environment.
    eval_env = make_env()

    try:
        # Periodically evaluate the deterministic policy and keep the best
        # model seen so far.
        eval_callback = EvalCallback(
            eval_env,
            best_model_save_path=os.path.join(log_dir, "best_model"),
            log_path=log_dir,
            eval_freq=1500,
            deterministic=True,
            render=False,
        )

        # NOTE(review): no action_noise is passed, so once learning_starts has
        # elapsed the agent acts with its deterministic policy output only.
        # Consider adding e.g. NormalActionNoise if exploration is too weak.
        model = TD3(
            "MlpPolicy",
            env,
            verbose=1,
            learning_rate=0.001,
            batch_size=256,
            buffer_size=100_000,
            learning_starts=10_000,
            train_freq=(1, "step"),
            tau=0.005,
            gamma=0.99,
            tensorboard_log=log_dir,
        )

        # Train the model
        model.learn(total_timesteps=600_000, callback=eval_callback)

        # Save the final model (which may differ from the best model saved by
        # the callback).
        model.save(os.path.join(log_dir, "final_model"))

        print("Training complete. Model saved.")

        mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
        print(f"Final evaluation reward: {mean_reward:.2f} ± {std_reward:.2f}")
    finally:
        # Release environment resources even when training or evaluation
        # raises, so repeated runs in the same kernel do not leak them.
        env.close()
        eval_env.close()

# Start training only when this code runs as the top-level program
# (``__name__`` equals "__main__"); importing it elsewhere does not call main().
if __name__ == "__main__":
    main()


Using cpu device
Logging to ./logs\TD3_2
f_agent: [2.510161  2.0643988 2.1028094 0.7309852], f_true: [1.8271125 1.8271125 1.8271125 1.8271125]
f_agent: [1.4104319  2.395739   0.84558946 3.0403118 ], f_true: [1.64044899 1.82277858 1.71311319 2.13034605]
f_agent: [2.8475566 3.4400878 1.2179065 2.7768917], f_true: [1.88946874 1.61586626 1.97510817 1.8192732 ]
f_agent: [1.3837922 1.7965621 3.6068082 1.4957123], f_true: [1.89744132 1.30705123 2.30030777 1.73659593]
f_agent: [3.3803306 1.1294178 1.8834254 3.5383763], f_true: [2.02288475 1.4521031  1.90490846 1.8435273 ]
f_agent: [1.4816927  0.58531547 2.4723089  0.9622712 ], f_true: [1.89269988 1.70491766 1.92151475 1.65378339]
f_agent: [3.0364294 2.038726  2.4681816 2.553276 ], f_true: [1.85751336 1.98149324 1.58576912 1.78660542]
f_agent: [1.2821214  3.4657536  0.80919147 3.2198663 ], f_true: [1.75099377 2.0980689  1.50905898 1.80037802]
f_agent: [0.80482364 3.2996874  2.2801785  3.4001894 ], f_true: [2.10271463 1.69244411 1.89091884 1.446