<a href="https://colab.research.google.com/github/shubhamt2897/Gymnasium_Robotics_Tutorial/blob/main/src/Fetch/Fetch_Slide_SAC_hyperparameters.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 1. Pin NumPy to our standard, stable version FIRST for maximum compatibility
!pip install "numpy==1.26.4"

# 2. Install Optuna for tuning and all our other required libraries
!pip install -U gymnasium gymnasium-robotics stable-baselines3["extra"] mujoco optuna plotly

In [None]:
# Import all necessary libraries
import gymnasium as gym
import gymnasium_robotics
from stable_baselines3 import SAC, HerReplayBuffer
import optuna
import torch as th
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

# --- Configuration ---
# Gymnasium-Robotics task whose SAC hyperparameters are being tuned.
ENV_ID = "FetchSlide-v3"
# Maximum number of Optuna trials; a GPU runtime makes 50 of them affordable.
N_TRIALS = 50
# Environment steps spent training the model inside every single trial.
# Deliberately small so that each trial completes quickly.
N_TIMESTEPS = 25_000
# How many evaluation episodes are rolled out per trained model to estimate
# its success rate.
N_EVAL_EPISODES = 25


# --- The Objective Function for Optuna ---
def _evaluate_success_rate(model: SAC, env_id: str, n_episodes: int) -> float:
    """
    Roll out deterministic evaluation episodes and return the fraction that succeed.

    Args:
        model: Trained agent; only its ``predict`` method is used.
        env_id: Gymnasium id of the environment to evaluate on.
        n_episodes: Number of episodes to roll out; must be positive.

    Returns:
        Successful episodes divided by ``n_episodes``. An episode counts as
        successful when ``info["is_success"]`` on its final step is truthy.

    Raises:
        ValueError: If ``n_episodes`` is not positive (avoids a division by zero).
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")

    eval_env = gym.make(env_id)
    successful_episodes = 0
    try:
        for _ in range(n_episodes):
            obs, _ = eval_env.reset()
            info = {}
            done = False
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, _, terminated, truncated, info = eval_env.step(action)
                done = terminated or truncated
            # Only the flag reported on the episode's last step decides success.
            if info.get("is_success"):
                successful_episodes += 1
    finally:
        # Close the simulator even if a rollout raises.
        eval_env.close()
    return successful_episodes / n_episodes


def objective(trial: optuna.Trial) -> float:
    """
    Train and evaluate one SAC + HER agent with hyperparameters suggested by Optuna.

    Samples a learning rate and a network-size preset. Trains SAC with a
    ``HerReplayBuffer`` on ``ENV_ID`` for ``N_TIMESTEPS`` steps. Then measures
    the success rate over ``N_EVAL_EPISODES`` deterministic episodes.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        The evaluation success rate in [0, 1], which the study maximizes.
    """
    print(f"\n--- Starting Trial #{trial.number} ---")

    # 1. Suggest hyperparameters.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    # The categorical choice is a plain string so Optuna can store and plot it;
    # it is mapped to concrete layer sizes here.
    net_arch_str = trial.suggest_categorical("net_arch", ["small", "medium", "big"])
    net_arch_map = {"small": [64, 64], "medium": [128, 128], "big": [256, 256]}
    policy_kwargs = dict(net_arch=net_arch_map[net_arch_str])

    # Use the GPU when one is present, but fall back to the CPU instead of
    # crashing on runtimes without CUDA.
    device = "cuda" if th.cuda.is_available() else "cpu"

    # 2. Create and train the model.
    train_env = gym.make(ENV_ID)
    try:
        model = SAC(
            "MultiInputPolicy",
            train_env,
            learning_rate=learning_rate,
            policy_kwargs=policy_kwargs,
            replay_buffer_class=HerReplayBuffer,
            # HER: sample 4 substitute goals per transition using the "future"
            # strategy (goals achieved later in the same episode).
            replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
            verbose=0,
            device=device,
        )
        model.learn(total_timesteps=N_TIMESTEPS)
    finally:
        # Release the training environment even if training fails; otherwise
        # every trial of the study would leak one simulator instance.
        train_env.close()

    # 3. Evaluate the model on a fresh environment.
    success_rate = _evaluate_success_rate(model, ENV_ID, N_EVAL_EPISODES)
    print(f"Trial #{trial.number} Finished. Success Rate: {success_rate:.2f}")

    # 4. Return the score for Optuna to maximize.
    return success_rate

# --- Start the Tuning Process ---
# A single study whose goal is to maximize the success rate `objective` returns.
study = optuna.create_study(direction="maximize")
# Run until N_TRIALS trials are done or 7200 s (2 hours) have elapsed,
# whichever happens first; the time cap keeps the run inside Colab's limits.
study.optimize(objective, timeout=7200, n_trials=N_TRIALS)

# --- Print the Best Results ---
for banner_line in (
    "\n--- Hyperparameter Tuning Complete ---",
    f"Number of finished trials: {len(study.trials)}",
    "Best trial:",
):
    print(banner_line)
best_trial = study.best_trial
print(f"  Value (Success Rate): {best_trial.value:.4f}")
print("  Params: ")
for param_name, param_value in best_trial.params.items():
    print(f"    {param_name}: {param_value}")

In [None]:
# --- Visualize the Results ---
# Requires the 'plotly' library installed in the first cell.

# Show the optimization history: each trial's score and the best value so far.
fig1 = optuna.visualization.plot_optimization_history(study)
fig1.show()

# Show the parameter importance. Optuna can fail to compute importances,
# e.g. when too few trials have completed or every trial returned the same
# score (plausible here if no agent ever succeeds). Report that and move on
# rather than aborting the cell, so the slice plots below are still drawn.
try:
    fig2 = optuna.visualization.plot_param_importances(study)
except (ValueError, RuntimeError) as err:
    print(f"Skipping parameter-importance plot: {err}")
else:
    fig2.show()

# Show slice plots to see how each parameter affects the outcome.
fig3 = optuna.visualization.plot_slice(study, params=["learning_rate", "net_arch"])
fig3.show()