In [92]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

class FunctionMinimizationEnv(gym.Env):
    """Gymnasium environment whose state is a parameter vector to be minimised.

    Each action is a small additive update to the parameter vector.  The
    reward is the negated value of ``target_function`` at the new parameters,
    so maximising the return corresponds to minimising the function.

    Args:
        param_dim: Number of parameters (length of observation and action).
        step_size: Per-component bound on an action; the action space is
            ``[-step_size, step_size]`` in every dimension.
        divergence_threshold: The episode terminates once the function value
            leaves ``[-divergence_threshold, divergence_threshold]``.
    """

    def __init__(self, param_dim=3, step_size=0.01, divergence_threshold=5.0):
        super().__init__()
        self.param_dim = param_dim
        self.divergence_threshold = divergence_threshold
        self.action_space = spaces.Box(
            low=-step_size, high=step_size, shape=(self.param_dim,), dtype=np.float32
        )
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.param_dim,), dtype=np.float32
        )
        self.state = None  # populated by reset()

    def reset(self, seed=None, options=None):
        """Start a new episode from a uniform random point in ``[-1, 1]^param_dim``.

        Returns:
            ``(observation, info)`` as required by the Gymnasium API.
        """
        # super().reset seeds self.np_random; drawing from that generator
        # (instead of the global np.random) is what makes `seed` take effect.
        super().reset(seed=seed)
        start = self.np_random.uniform(-1, 1, size=(self.param_dim,))
        # Cast so the observation matches the float32 dtype declared above.
        self.state = start.astype(np.float32)
        return self.state, {}

    def step(self, action):
        """Apply ``action`` as an additive update and evaluate the target function.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.state is None:
            raise RuntimeError("reset() must be called before step()")

        delta = np.asarray(action, dtype=np.float32)
        self.state = (self.state + delta).astype(np.float32)

        value = self.target_function(self.state)
        reward = -float(value)  # Negative because we want to minimize
        # Stop the episode when the function value leaves the allowed band.
        terminated = bool(
            value < -self.divergence_threshold or value > self.divergence_threshold
        )
        truncated = False  # Time limits are left to wrappers.
        info = {}
        return self.state, reward, terminated, truncated, info

    def target_function(self, x):
        """Function to minimise; this example is the sum of squares of ``x``."""
        return float(np.sum(np.square(x)))


In [93]:
from stable_baselines3 import SAC
from gymnasium.wrappers import TimeLimit

# Fixed seed handed to SAC's pseudo-random generators so that a
# Restart & Run All of the notebook repeats the same training run.
SEED = 0
# Upper bound on episode length enforced by the TimeLimit wrapper.
MAX_EPISODE_STEPS = 5000

env = FunctionMinimizationEnv()
# TimeLimit reports truncated=True once MAX_EPISODE_STEPS steps have elapsed.
env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
model = SAC("MlpPolicy", env, verbose=1, seed=SEED)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [94]:
# Train the agent for 1,000 environment steps -- fewer than the 5000-step
# episode cap set on the TimeLimit wrapper.  The call returns the model
# itself, which is why the cell echoes the SAC object's repr.
model.learn(total_timesteps=1_000)

<stable_baselines3.sac.sac.SAC at 0x14b49c7e480>

In [95]:
# Roll out a single episode with the trained policy (deterministic=True),
# stopping as soon as the environment reports termination or truncation.
obs, info = env.reset()

while True:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        break
# obs now holds the parameter vector returned by the final step of the episode

In [96]:
# Show the last transition of the rollout: final parameters, their reward,
# and the termination / truncation flags.
final_transition = (obs, reward, terminated, truncated, info)
print(*final_transition)

[-0.50335988 -0.15649646  0.93774354] -1.1572252491171926 False True {}
