In [None]:
# Install the third-party packages used by the cells below
# (dm_control for the environments, pink-noise-rl for the action noise,
# wandb for experiment logging).
!pip install dm_control
!pip install pink-noise-rl
!pip install wandb

In [None]:
import gym
from gym import spaces

from dm_control import suite
from dm_env import specs


def convert_dm_control_to_gym_space(dm_control_space):
    r"""Convert a dm_control spec into the corresponding gym space.

    Args:
        dm_control_space: A ``specs.BoundedArray``, a ``specs.Array`` or a
            ``dict`` whose values are themselves convertible (dicts are
            converted recursively, key by key).

    Returns:
        A ``spaces.Box`` for array specs (using the spec's minimum/maximum
        when it is bounded, infinite bounds otherwise) or a ``spaces.Dict``
        for a dict of specs.

    Raises:
        NotImplementedError: If ``dm_control_space`` is none of the supported
            types. Failing here avoids silently handing back ``None`` as a
            space.
    """
    # Bounded specs are handled first so that the generic array branch below
    # only ever sees specs without explicit bounds.
    if isinstance(dm_control_space, specs.BoundedArray):
        space = spaces.Box(low=dm_control_space.minimum,
                           high=dm_control_space.maximum,
                           dtype=dm_control_space.dtype)
        # The Box shape is derived from the bounds; it must agree with the spec.
        assert space.shape == dm_control_space.shape
        return space
    if isinstance(dm_control_space, specs.Array):
        return spaces.Box(low=-float('inf'),
                          high=float('inf'),
                          shape=dm_control_space.shape,
                          dtype=dm_control_space.dtype)
    if isinstance(dm_control_space, dict):
        return spaces.Dict({key: convert_dm_control_to_gym_space(value)
                            for key, value in dm_control_space.items()})
    raise NotImplementedError(
        f"Cannot convert spec of type {type(dm_control_space).__name__} to a gym space")


class DMSuiteEnv(gym.Env):
    """Gym-style wrapper around an environment loaded from ``dm_control.suite``.

    The observation and action spaces are derived from the suite
    environment's specs, and ``step`` returns the 4-tuple
    ``(observation, reward, done, info)``.
    """

    def __init__(self, domain_name, task_name, task_kwargs=None, environment_kwargs=None, visualize_reward=False):
        """Load the suite task and build the gym metadata and spaces.

        All arguments are forwarded unchanged to ``suite.load``.
        """
        self.env = suite.load(domain_name, 
                              task_name, 
                              task_kwargs=task_kwargs, 
                              environment_kwargs=environment_kwargs, 
                              visualize_reward=visualize_reward)
        # Video frame rate is the inverse of the control timestep.
        self.metadata = {'render.modes': ['human', 'rgb_array'],
                         'video.frames_per_second': round(1.0/self.env.control_timestep())}
        # The raw specs and the converted spaces are printed for inspection.
        print(self.env.observation_spec())
        self.observation_space = convert_dm_control_to_gym_space(self.env.observation_spec())
        print(self.observation_space)
        print("________________________")
        print(self.env.action_spec())
        self.action_space = convert_dm_control_to_gym_space(self.env.action_spec())
        print(self.action_space)
        # Created lazily on the first 'human' render call.
        self.viewer = None
    
    def seed(self, seed):
        """Seed the task's random state; returns the result of that call."""
        return self.env.task.random.seed(seed)
    
    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``done`` is true when the underlying timestep is the last one of its
        episode; ``info`` is always an empty dict.
        """
        timestep = self.env.step(action)
        return timestep.observation, timestep.reward, timestep.last(), {}
    
    def reset(self):
        """Reset the wrapped environment and return the first observation."""
        timestep = self.env.reset()
        return timestep.observation
    
    def render(self, mode='human', **kwargs):
        """Render the current physics state.

        Args:
            mode: ``'rgb_array'`` returns the rendered image; ``'human'``
                shows it in a viewer window and returns whether that window
                is still open.
            **kwargs: Forwarded to ``physics.render``. ``camera_id`` defaults
                to 0, and ``use_opencv_renderer`` (default ``False``) is
                consumed here to pick the viewer implementation.

        Raises:
            NotImplementedError: For any other ``mode``.
        """
        if 'camera_id' not in kwargs:
            kwargs['camera_id'] = 0  # Tracking camera
        use_opencv_renderer = kwargs.pop('use_opencv_renderer', False)
        
        img = self.env.physics.render(**kwargs)
        if mode == 'rgb_array':
            return img
        elif mode == 'human':
            if self.viewer is None:
                if not use_opencv_renderer:
                    from gym.envs.classic_control import rendering
                    self.viewer = rendering.SimpleImageViewer(maxwidth=1024)
                else:
                    # NOTE(review): this relative import only resolves when the
                    # class lives inside a package that provides
                    # OpenCVImageViewer - confirm before passing
                    # use_opencv_renderer=True from a notebook.
                    from . import OpenCVImageViewer
                    self.viewer = OpenCVImageViewer()
            self.viewer.imshow(img)
            return self.viewer.isopen
        else:
            raise NotImplementedError

    def close(self):
        """Close the viewer (if one was opened) and the wrapped environment."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
        return self.env.close()

In [None]:
# Parallel lists describing the benchmark tasks: env1[k] is passed as the
# domain name and env2[k] as the task name when environment k is built.
env1 = [
    "cartpole",
    "cartpole",
    "ball_in_cup",
    "hopper",
    "cheetah",
    "reacher",
    "pendulum",
]
env2 = [
    "balance_sparse",
    "swingup_sparse",
    "catch",
    "hop",
    "run",
    "hard",
    "swingup",
]

In [None]:
import time

import gymnasium as gym
import numpy as np
import torch
import wandb
from pink import ColoredActionNoise
from pink import PinkActionNoise
from stable_baselines3 import TD3
from tqdm import tqdm

# Define a function to evaluate an episode
def evaluate_episode(model, env, max_steps=1000):
    """Run one deterministic evaluation episode and return its total reward.

    Args:
        model: Agent exposing ``predict(obs, deterministic=True)`` that
            returns an ``(action, state)`` pair.
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(obs, reward, done, info)``.
        max_steps: Upper bound on the number of environment steps. The
            episode is cut off after this many steps even if ``done`` was
            never reported.

    Returns:
        The sum of the rewards collected during the episode (``0.0`` when no
        step is taken).
    """
    obs = env.reset()
    total_reward = 0.0
    for _ in range(max_steps):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, _info = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward

# Reproducibility: seed the global NumPy / PyTorch generators and create a
# Generator for the exploration-noise processes.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
rng = np.random.default_rng(seed)


def _train_and_evaluate(model, eval_env, train_steps, n_rollouts):
    """Train ``model`` for ``train_steps`` further steps, then evaluate it.

    Returns:
        ``(mean_return, train_seconds)``: the mean return over ``n_rollouts``
        evaluation episodes on ``eval_env`` and the wall-clock time spent in
        ``learn`` (evaluation time is not included).
    """
    start = time.time()
    # reset_num_timesteps=False makes this call continue the previous one
    # instead of restarting the agent's timestep counter (and with it the
    # initial warm-up phase) on every chunk.
    model.learn(total_timesteps=train_steps, reset_num_timesteps=False)
    train_seconds = time.time() - start

    episode_returns = [evaluate_episode(model, eval_env) for _ in range(n_rollouts)]
    return sum(episode_returns) / n_rollouts, train_seconds


for domain_name, task_name in zip(env1, env2):
    # Each agent trains on its own environment and evaluation uses a separate
    # one: training continues across learn() calls, so stepping a shared
    # environment elsewhere would desynchronise it from the agent's stored
    # last observation.
    eval_env = DMSuiteEnv(domain_name, task_name)
    train_envs = [DMSuiteEnv(domain_name, task_name) for _ in range(3)]
    action_dim = eval_env.action_space.shape[-1]
    seq_len = 1000
    # Fresh generator per environment so the noise stream does not depend on
    # which environments were run before.
    rng = np.random.default_rng(seed)

    # Initialize agents
    # NOTE(review): model_default keeps TD3's own action-noise setting but is
    # reported as "White" in the summary below - confirm that this baseline
    # explores the way the comparison intends.
    model_default = TD3("MultiInputPolicy", train_envs[0])
    model_pink = TD3("MultiInputPolicy", train_envs[1])
    model_OU = TD3("MultiInputPolicy", train_envs[2])

    # Set action noise (beta=2 coloured noise is used for the "OU" agent)
    noise_scale = 0.3
    model_pink.action_noise = PinkActionNoise(noise_scale, seq_len, action_dim, rng=rng)
    model_OU.action_noise = ColoredActionNoise(2, noise_scale, seq_len, action_dim, rng=rng)

    # Training parameters
    total_timesteps = 1000000
    eval_frequency = 10000  # Evaluate every 10^4 interactions
    eval_rollouts = 5

    wandb.init(
        project="Pinkie",
        config={
            "Total_timesteps": total_timesteps,
            "Eval_frequency": eval_frequency,
            "Eval_rollouts": eval_rollouts,
            "Environment": domain_name + " " + task_name,
        },
    )

    # (metric key, label used in the return print, label used in the timing
    # print, agent)
    agents = [
        ("default", "Default", "Default Model", model_default),
        ("pink", "Pink", "Pink Noise Model", model_pink),
        ("OU", "OU", "OU Noise Model", model_OU),
    ]
    # Mean evaluation return of every agent at every evaluation point.
    history = {key: [] for key, *_ in agents}

    # Train in chunks of eval_frequency steps, evaluating after each chunk.
    for timesteps_before in tqdm(range(0, total_timesteps, eval_frequency)):
        # Number of training steps each agent has completed once this chunk
        # is done; this is the value that is printed and logged.
        timesteps_so_far = timesteps_before + eval_frequency
        metrics = {}

        for key, label, timing_label, model in agents:
            mean_return, train_seconds = _train_and_evaluate(
                model, eval_env, eval_frequency, eval_rollouts)
            history[key].append(mean_return)
            metrics[f"mean_return_{key}"] = mean_return

            print(f"Return ({label}): {mean_return}")
            print(f"Time taken ({timing_label}): {train_seconds:.2f} seconds")
            print(f"Timesteps: {timesteps_so_far}, Mean Return: {mean_return}")

        metrics["timesteps_so_far"] = timesteps_so_far
        wandb.log(metrics)

    # Average performance: mean over all evaluations.
    # Final performance: mean over the last 5% of evaluations (at least one).
    n_evals = len(history["default"])
    n_final = max(1, n_evals // 20)

    avg_default = sum(history["default"]) / n_evals
    avg_pink = sum(history["pink"]) / n_evals
    avg_OU = sum(history["OU"]) / n_evals
    final_default = sum(history["default"][-n_final:]) / n_final
    final_pink = sum(history["pink"][-n_final:]) / n_final
    final_OU = sum(history["OU"][-n_final:]) / n_final

    # Log and print the summary for this environment before moving on, so the
    # results of every environment are kept, not only those of the last one.
    wandb.log({
        "final_default": final_default,
        "final_pink": final_pink,
        "final_OU": final_OU,
        "avg_default": avg_default,
        "avg_pink": avg_pink,
        "avg_OU": avg_OU
    })

    print(f"Environment: {domain_name} {task_name}")
    print("Mean:")
    print(f"White:{avg_default}           Pink:{avg_pink}             OU:{avg_OU}")
    print("Final:")
    print(f"White:{final_default}           Pink:{final_pink}             OU:{final_OU}")

    # End this environment's run so the next wandb.init starts a new one, and
    # release the simulator instances.
    wandb.finish()
    for used_env in (*train_envs, eval_env):
        used_env.close()

In [None]:
# env = DMSuiteEnv("cartpole","balance_sparse")
# env = DMSuiteEnv("cartpole","swingup_sparse")
# env = DMSuiteEnv("ball_in_cup","catch")
# env = DMSuiteEnv("hopper","hop")
# env = DMSuiteEnv("cheetah","run")
# env = DMSuiteEnv("cheetah","run")
# env = DMSuiteEnv("pendulum","swingup")
# env = DMSuiteEnv("reacher","hard")