In [1]:
!pip install gym matplotlib glfw mujoco ray[rllib]



In [2]:
import ray
from ray.rllib.algorithms.sac import SAC
from ray.tune.logger import pretty_print

def get_max_action(env_name):
    import gym
    env = gym.make(env_name)
    return env.action_space.high[0]

max_action = get_max_action("Reacher-v4")

config_dict_1 = {
    "env": "Reacher-v4",
    "framework": "torch",
    "num_gpus": 1,
    "num_workers": 0,
    "algorithm": "SAC",
    "use_automatic_entropy_tuning": True,
    "target_entropy": "auto",
    "model": {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "relu",
    },
    "optimization": {
        "actor_learning_rate": 0.0001,
        "critic_learning_rate": 0.0001,
        "entropy_learning_rate": 0.0003,
        "learning_starts": 1000,
    },
    "exploration_config": {
        "type": "OrnsteinUhlenbeckNoise",
        "ou_base_scale": 0.5,
        "ou_theta": 0.15,
        "ou_sigma": 0.2,
    },
    "replay_buffer_config": {
        "capacity": 1000000,
        "learning_starts": 1000,
        "type": "MultiAgentReplayBuffer",
    },
    "target_network_update_freq": 5,
    "tau": 0.02,
    "gamma": 0.95,
    "clip_actions": True,
    "train_batch_size": 128,
    "rollout_fragment_length": 1,
    "batch_mode": "complete_episodes",
    "stop": {
        "training_iteration": 10000,
    },
}

config_dict_2 = {
    "env": "Reacher-v4",
    "framework": "torch",
    "num_gpus": 0,
    "num_workers": 0,
    "algorithm": "SAC",
    "use_automatic_entropy_tuning": False,
    "target_entropy": -2,
    "model": {
        "fcnet_hiddens": [64, 64],
        "fcnet_activation": "tanh",
    },
    "optimization": {
        "actor_learning_rate": 0.0005,
        "critic_learning_rate": 0.0005,
        "entropy_learning_rate": 0.0001,
        "learning_starts": 2000,
    },
    "exploration_config": {
        "type": "GaussianNoise",
        "stddev": 0.1,
    },
    "replay_buffer_config": {
        "capacity": 500000,
        "learning_starts": 2000,
        "type": "MultiAgentReplayBuffer",
    },
    "target_network_update_freq": 20,
    "tau": 0.01,
    "gamma": 0.99,
    "clip_actions": True,
    "train_batch_size": 32,
    "rollout_fragment_length": 1,
    "batch_mode": "complete_episodes",
    "stop": {
        "training_iteration": 15000,
    },
}

config_dict_3 = {
    "env": "Reacher-v4",
    "framework": "torch",
    "num_gpus": 0,
    "num_workers": 0,
    "algorithm": "SAC",
    "use_automatic_entropy_tuning": True,
    "target_entropy": "auto",
    "model": {
        "fcnet_hiddens": [128, 128],
        "fcnet_activation": "relu",
    },
    "optimization": {
        "actor_learning_rate": 0.0005,
        "critic_learning_rate": 0.0005,
        "entropy_learning_rate": 0.0005,
        "learning_starts": 500,
    },
    "exploration_config": {
        "type": "OrnsteinUhlenbeckNoise",
        "ou_base_scale": 0.3,
        "ou_theta": 0.45,
        "ou_sigma": 0.3,
    },
    "replay_buffer_config": {
        "capacity": 500000,
        "learning_starts": 500,
        "type": "MultiAgentReplayBuffer",
    },
    "target_network_update_freq": 5,
    "tau": 0.01,
    "gamma": 0.95,
    "clip_actions": True,
    "train_batch_size": 256,
    "rollout_fragment_length": 1,
    "batch_mode": "complete_episodes",
    "stop": {
        "training_iteration": 10000,
    },
}

import gym
import numpy as np

class CustomReacherEnv(gym.Wrapper):
    def __init__(self, env_name):
        super(CustomReacherEnv, self).__init__(gym.make(env_name))

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Modify the reward using a custom function
        modified_reward = self.custom_reward_function(obs, reward, info)

        return obs, modified_reward, done, info

    def custom_reward_function(self, obs, reward, info):
        # Example of a custom reward function
        # Extract relevant information for reward calculation
        distance = np.linalg.norm(obs[-3:-1])  # Distance from fingertip to target
        control_cost = np.sum(np.square(action))  # Cost of action magnitude

        # Customize these weights as needed
        distance_weight = 1.0
        control_weight = 0.1

        # Calculate the modified reward
        modified_reward = -distance_weight * distance - control_weight * control_cost

        return modified_reward

config_dict_4 = {
    "env": CustomReacherEnv("Reacher-v4"),
    "framework": "torch",
    "num_gpus": 0,
    "num_workers": 0,
    "algorithm": "SAC",
    "use_automatic_entropy_tuning": True,
    "target_entropy": "auto",
    "model": {
        "fcnet_hiddens": [400, 300],  # More complex network
        "fcnet_activation": "relu",
    },
    "optimization": {
        "actor_learning_rate": 0.001,  # Slightly higher learning rate
        "critic_learning_rate": 0.001,
        "entropy_learning_rate": 0.001,
        "learning_starts": 1500,  # Start learning after more initial exploration
    },
    "exploration_config": {
        "type": "OrnsteinUhlenbeckNoise",
        "ou_base_scale": 0.2,
        "ou_theta": 0.3,
        "ou_sigma": 0.2,
    },
    "replay_buffer_config": {
        "capacity": 1000000,  # Larger replay buffer
        "learning_starts": 1500,
        "type": "MultiAgentReplayBuffer",
    },
    "target_network_update_freq": 1,
    "tau": 0.005,
    "gamma": 0.99,
    "clip_actions": True,
    "train_batch_size": 128,
    "rollout_fragment_length": 1,
    "batch_mode": "complete_episodes",
    "stop": {
        "training_iteration": 10000,
    },
}


# Initialize Ray
ray.init()

# Create and configure the SAC algorithm
sac_trainer = SAC(config=config_dict_2)

# Training loop
for i in range(500):  # Number of training iterations
    result = sac_trainer.train()
    print(pretty_print(result))

# Cleanup
ray.shutdown()


  return jax_config.define_bool_state('flax_' + name, default, help)
  from jax.nn import normalize
  if (distutils.version.LooseVersion(tf.__version__) <
  deprecation(
  deprecation(
2023-12-21 16:49:09,222	INFO worker.py:1724 -- Started a local Ray instance.
`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))


agent_timesteps_total: 100
connector_metrics:
  ObsPreprocessorConnector_ms: 0.011599063873291016
  StateBufferConnector_ms: 0.007033348083496094
  ViewRequirementAgentConnector_ms: 1.1858105659484863
counters:
  num_agent_steps_sampled: 100
  num_agent_steps_trained: 0
  num_env_steps_sampled: 100
  num_env_steps_trained: 0
custom_metrics: {}
date: 2023-12-21_16-49-13
done: false
episode_len_mean: 50.0
episode_media: {}
episode_reward_max: -42.27276631577412
episode_reward_mean: -47.61032826383281
episode_reward_min: -52.9478902118915
episodes_this_iter: 2
episodes_total: 2
hostname: 4900223327f7
info:
  learner: {}
  num_agent_steps_sampled: 100
  num_agent_steps_trained: 0
  num_env_steps_sampled: 100
  num_env_steps_trained: 0
iterations_since_restore: 1
node_ip: 172.28.0.12
num_agent_steps_sampled: 100
num_agent_steps_trained: 0
num_env_steps_sampled: 100
num_env_steps_sampled_this_iter: 100
num_env_steps_sampled_throughput_per_sec: 95.10089055959666
num_env_steps_trained: 0
num_e



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        0.2850484848022461, 0.362335205078125, 0.41675758361816406, 0.680537223815918,
        0.029700279235839844, 0.3155384063720703, 0.34917640686035156, 0.2590975761413574,
        0.12436771392822266, 0.0751352310180664, 0.23122692108154297, 0.14208221435546875]
  num_agent_steps_sampled: 176200
  num_agent_steps_trained: 111808
  num_env_steps_sampled: 176200
  num_env_steps_trained: 111808
  num_target_updates: 3494
iterations_since_restore: 469
node_ip: 172.28.0.12
num_agent_steps_sampled: 176200
num_agent_steps_trained: 111808
num_env_steps_sampled: 176200
num_env_steps_sampled_this_iter: 450
num_env_steps_sampled_throughput_per_sec: 412.8444304163777
num_env_steps_trained: 111808
num_env_steps_trained_this_iter: 288
num_env_steps_trained_throughput_per_sec: 264.2204354664817
num_faulty_episodes: 0
num_healthy_workers: 0
num_in_flight_async_reqs: 0
num_remote_worker_restarts: 0
num_steps_trained_this_iter: 288
p

In [3]:
checkpoint_path = sac_trainer.save("/content")

  and should_run_async(code)


In [4]:
!apt-get update

0% [Working]            Get:1 http://security.ubuntu.com/ubuntu jammy-security InRelease [110 kB]
0% [Connecting to archive.ubuntu.com (91.189.91.83)] [1 InRelease 2,585 B/110 kB 2%] [Connected to c0% [Connecting to archive.ubuntu.com (91.189.91.83)] [Waiting for headers] [Waiting for headers] [Wa                                                                                                    Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
                                                                                                    0% [Connecting to archive.ubuntu.com (91.189.91.83)] [Waiting for headers] [Waiting for headers]                                                                                                Hit:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
                                                                                                0% [Waiting for headers]

In [5]:
!apt-get install xvfb --fix-missing

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common
The following NEW packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common xvfb
0 upgraded, 9 newly installed, 0 to remove and 27 not upgraded.
Need to get 7,813 kB of archives.
After this operation, 11.9 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libfontenc1 amd64 1:1.1.4-1build3 [14.7 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxfont2 amd64 1:2.0.5-1build1 [94.5 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxkbfile1 amd64 1:1.1.0-1build3 [71.8 kB]
Get:4 http://archive.ubuntu.com/ubuntu jammy/main amd64 x11-xkb-utils amd64 7.7+5build4 [172 kB]
Get:5 http://archiv

In [6]:
!pip install pyvirtualdisplay

Collecting pyvirtualdisplay
  Downloading PyVirtualDisplay-3.0-py3-none-any.whl (15 kB)
Installing collected packages: pyvirtualdisplay
Successfully installed pyvirtualdisplay-3.0


In [7]:
from pyvirtualdisplay import Display

# Create a virtual display
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()


<pyvirtualdisplay.display.Display at 0x7a06bf6fe110>

In [8]:
import gym
import cv2


# Load the trained model for testing
#algo.restore(checkpoint_path)
# Initialize the environment
env = gym.make("Reacher-v4")

# Setup video writer
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('Reacher-v4_output_sac.avi', fourcc, 20.0, (500, 500))

# Test the agent
state = env.reset()
for _ in range(100):
    action = sac_trainer.compute_single_action(state)
    next_state, _, _, _ = env.step(action)
    frame = env.render(mode='rgb_array')
    frame_resized = cv2.resize(frame, (500, 500))
    out.write(cv2.cvtColor(frame_resized, cv2.COLOR_RGB2BGR))
    state = next_state

# Clean up
env.close()
out.release()

# Close the virtual display
virtual_display.stop()

  and should_run_async(code)
  deprecation(
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(


<pyvirtualdisplay.display.Display at 0x7a06bf6fe110>