In [None]:
!pip install dm_control
!pip install pink-noise-rl
!pip install wandb

In [None]:
from gym import core, spaces
from dm_control import suite
from dm_env import specs
import numpy as np


def _spec_to_box(spec, dtype):
    """Build one flat `spaces.Box` from an iterable of dm_env array specs.

    Each spec contributes `prod(spec.shape)` entries; the per-spec bounds are
    concatenated in iteration order.

    Args:
        spec: iterable of `specs.Array` / `specs.BoundedArray` instances with a
            float32 or float64 dtype.
        dtype: dtype of the resulting box and of its `low`/`high` arrays.

    Returns:
        A one-dimensional `spaces.Box`. Unbounded specs map to `(-inf, inf)`.

    Raises:
        AssertionError: if a spec does not have a float32/float64 dtype.
        TypeError: if an element is not a dm_env array spec.
    """
    def extract_min_max(s):
        assert s.dtype == np.float64 or s.dtype == np.float32
        dim = int(np.prod(s.shape))
        # BoundedArray derives from Array in dm_env, so it has to be tested
        # first; otherwise every bounded spec would be treated as unbounded.
        if isinstance(s, specs.BoundedArray):
            zeros = np.zeros(dim, dtype=np.float32)
            # Adding zeros broadcasts scalar bounds to `dim` entries.
            return s.minimum + zeros, s.maximum + zeros
        if isinstance(s, specs.Array):
            bound = np.inf * np.ones(dim, dtype=np.float32)
            return -bound, bound
        # Previously this fell through and returned None, which surfaced as an
        # opaque "cannot unpack non-iterable NoneType" error in the caller.
        raise TypeError(
            'unsupported spec type: %s' % type(s).__name__
        )

    mins, maxs = [], []
    for s in spec:
        mn, mx = extract_min_max(s)
        mins.append(mn)
        maxs.append(mx)
    low = np.concatenate(mins, axis=0).astype(dtype)
    high = np.concatenate(maxs, axis=0).astype(dtype)
    assert low.shape == high.shape
    return spaces.Box(low, high, dtype=dtype)


def _flatten_obs(obs):
    obs_pieces = []
    for v in obs.values():
        flat = np.array([v]) if np.isscalar(v) else v.ravel()
        obs_pieces.append(flat)
    return np.concatenate(obs_pieces, axis=0)


class DMCWrapper(core.Env):
    """Gym-style wrapper around a task loaded with `dm_control.suite.load`.

    Actions are exposed through a normalized `[-1, 1]` box and linearly
    rescaled to the task's own action bounds before being applied.
    Observations are either the flattened task observation or a rendered
    frame, depending on `from_pixels`. The flattened task observation of the
    most recent `reset`/`step` is always kept in `current_state`.

    Attributes not found on the wrapper are looked up on the wrapped
    dm_control environment.
    """

    def __init__(
        self,
        domain_name,
        task_name,
        task_kwargs=None,
        visualize_reward=False,
        from_pixels=False,
        height=84,
        width=84,
        camera_id=0,
        frame_skip=1,
        environment_kwargs=None,
        channels_first=True
    ):
        """Load the task and build the action, observation and state spaces.

        Args:
            domain_name: forwarded to `suite.load`.
            task_name: forwarded to `suite.load`.
            task_kwargs: forwarded to `suite.load`.
            visualize_reward: forwarded to `suite.load`.
            from_pixels: if true, observations are rendered frames instead of
                the flattened task observation.
            height: default render height.
            width: default render width.
            camera_id: default render camera.
            frame_skip: number of environment steps taken per `step` call;
                must be at least 1.
            environment_kwargs: forwarded to `suite.load`.
            channels_first: for pixel observations, return frames as
                (3, height, width) instead of (height, width, 3).

        Raises:
            ValueError: if `frame_skip` is smaller than 1.
        """
        # With zero repeats `step` would never call the wrapped environment
        # and would fail on unbound locals, so reject it up front.
        if frame_skip < 1:
            raise ValueError(
                'frame_skip must be at least 1, given %s' % frame_skip
            )

        # NOTE: a seed in task_kwargs['random'] is not required here; the
        # spaces are seeded with a fixed value at the end of this method.
        self._from_pixels = from_pixels
        self._height = height
        self._width = width
        self._camera_id = camera_id
        self._frame_skip = frame_skip
        self._channels_first = channels_first

        # create task
        self._env = suite.load(
            domain_name=domain_name,
            task_name=task_name,
            task_kwargs=task_kwargs,
            visualize_reward=visualize_reward,
            environment_kwargs=environment_kwargs
        )

        # true and normalized action spaces
        self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32)
        self._norm_action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=self._true_action_space.shape,
            dtype=np.float32
        )

        # create observation space
        if from_pixels:
            shape = [3, height, width] if channels_first else [height, width, 3]
            self._observation_space = spaces.Box(
                low=0, high=255, shape=shape, dtype=np.uint8
            )
        else:
            self._observation_space = _spec_to_box(
                self._env.observation_spec().values(),
                np.float64
            )

        # The state space always describes the flattened task observation,
        # even when observations are pixels.
        self._state_space = _spec_to_box(
            self._env.observation_spec().values(),
            np.float64
        )

        self.current_state = None

        # set seed
        self.seed(seed=996)

    def __getattr__(self, name):
        """Delegate unknown attribute lookups to the wrapped environment.

        Raises:
            AttributeError: if `_env` itself is requested before it exists.
        """
        # __getattr__ only runs when normal lookup fails. If `_env` is the
        # missing name (e.g. copy/pickle probing an instance whose __init__
        # never ran), evaluating `self._env` below would re-enter this method
        # forever and end in a RecursionError.
        if name == '_env':
            raise AttributeError(name)
        return getattr(self._env, name)

    def _get_obs(self, time_step):
        """Return the observation for `time_step` (pixels or flat vector)."""
        if self._from_pixels:
            obs = self.render(
                height=self._height,
                width=self._width,
                camera_id=self._camera_id
            )
            if self._channels_first:
                # Move the channel axis to the front.
                obs = obs.transpose(2, 0, 1).copy()
        else:
            obs = _flatten_obs(time_step.observation)
        return obs

    def _convert_action(self, action):
        """Linearly map an action from the normalized box to the true box."""
        # Do the arithmetic in float64, then hand back float32.
        action = action.astype(np.float64)
        true_delta = self._true_action_space.high - self._true_action_space.low
        norm_delta = self._norm_action_space.high - self._norm_action_space.low
        # First rescale to [0, 1], then to the true bounds.
        action = (action - self._norm_action_space.low) / norm_delta
        action = action * true_delta + self._true_action_space.low
        action = action.astype(np.float32)
        return action

    @property
    def observation_space(self):
        """Space of the observations returned by `reset` and `step`."""
        return self._observation_space

    @property
    def state_space(self):
        """Space of the flattened task observation stored in `current_state`."""
        return self._state_space

    @property
    def action_space(self):
        """Normalized `[-1, 1]` action space expected by `step`."""
        return self._norm_action_space

    @property
    def reward_range(self):
        """Return `(0, frame_skip)`.

        NOTE(review): presumably assumes a per-step reward in [0, 1] summed
        over `frame_skip` steps -- confirm against the task.
        """
        return 0, self._frame_skip

    def seed(self, seed):
        """Seed the true action, normalized action and observation spaces.

        The state space is not seeded here.
        """
        self._true_action_space.seed(seed)
        self._norm_action_space.seed(seed)
        self._observation_space.seed(seed)

    def step(self, action):
        """Apply a normalized action for up to `frame_skip` environment steps.

        Args:
            action: array contained in the normalized action space.

        Returns:
            `(obs, reward, done, extra)`, where `reward` is the sum over the
            repeated steps and `extra` holds `internal_state` (the physics
            state captured before stepping) and the last step's `discount`.
        """
        assert self._norm_action_space.contains(action)
        action = self._convert_action(action)
        assert self._true_action_space.contains(action)
        reward = 0
        # Snapshot the physics state as it was before this action is applied.
        extra = {'internal_state': self._env.physics.get_state().copy()}

        # frame_skip >= 1 is enforced in __init__, so the loop body runs at
        # least once and `time_step` / `done` are always bound afterwards.
        for _ in range(self._frame_skip):
            time_step = self._env.step(action)
            # A reward of None is counted as zero.
            reward += time_step.reward or 0
            done = time_step.last()
            if done:
                break
        obs = self._get_obs(time_step)
        self.current_state = _flatten_obs(time_step.observation)
        extra['discount'] = time_step.discount
        return obs, reward, done, extra

    def reset(self):
        """Reset the wrapped environment and return the first observation."""
        time_step = self._env.reset()
        self.current_state = _flatten_obs(time_step.observation)
        obs = self._get_obs(time_step)
        return obs

    def render(self, mode='rgb_array', height=None, width=None, camera_id=0):
        """Render a frame through the wrapped environment's physics.

        Args:
            mode: must be 'rgb_array'.
            height: frame height; falls back to the constructor value.
            width: frame width; falls back to the constructor value.
            camera_id: camera to render from.

        Returns:
            The result of `physics.render` for the resolved arguments.
        """
        assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
        height = height or self._height
        width = width or self._width
        # NOTE(review): 0 is falsy, so camera_id=0 (also the default) always
        # resolves to the constructor's camera; camera 0 cannot be selected
        # explicitly when the wrapper was built with a different camera.
        camera_id = camera_id or self._camera_id
        return self._env.physics.render(
            height=height, width=width, camera_id=camera_id
        )

In [None]:
# Build the walker/run task with the wrapper's default constructor arguments.
env = DMCWrapper("walker","run")

In [None]:
import torch
# Select the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Show which device was selected.
print(device)

In [None]:
import time

import gymnasium as gym
import numpy as np
import torch
import wandb
from pink import ColoredNoiseDist
from pink import PinkNoiseDist
from stable_baselines3 import SAC
from tqdm import tqdm

# Define a function to evaluate an episode
def evaluate_episode(model, env, max_steps=1000):
    """Run one deterministic episode and return the summed reward.

    Args:
        model: object whose `predict(obs, deterministic=True)` returns an
            `(action, state)` pair.
        env: environment whose `reset()` returns an observation and whose
            `step(action)` returns `(obs, reward, done, info)`.
        max_steps: upper bound on the number of environment steps taken.

    Returns:
        The sum of rewards collected until `done` is reported or `max_steps`
        steps have been taken, whichever comes first.
    """
    obs = env.reset()
    total_reward = 0.0
    for _step in range(max_steps):
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, _info = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward

# Reproducibility
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
# A single generator is shared by both colored-noise distributions below.
rng = np.random.default_rng(seed)

# Dimensions handed to the noise distributions
action_dim = env.action_space.shape[-1]
seq_len = 1000

# Initialize agents (all three train on the same `env` instance)
model_default = SAC("MlpPolicy", env, seed=seed)
model_pink = SAC("MlpPolicy", env, seed=seed)
model_OU = SAC("MlpPolicy", env, seed=seed)

# Set action noise; the default model keeps SAC's own action distribution.
model_pink.actor.action_dist = PinkNoiseDist(seq_len, action_dim, rng=rng)
# Colored noise with beta=2 is what this script reports as "OU".
model_OU.actor.action_dist = ColoredNoiseDist(beta=2, seq_len=seq_len, action_dim=action_dim, rng=rng)

# Training parameters
total_timesteps = 1000000
eval_frequency = 10000  # Evaluate every 10^4 interactions
eval_rollouts = 5

wandb.init(
    project="Pinkie",
    config={
        "Total_timesteps": total_timesteps,
        "Eval_frequency": eval_frequency,
        "Eval_rollouts": eval_rollouts,
    },
)


def _train_and_evaluate(model, short_label, long_label, timesteps_before):
    """Train `model` for one chunk, evaluate it, print and return the mean return.

    `timesteps_before` is the step count before this chunk and is only used
    in the printed summary.
    """
    start = time.time()
    model.learn(total_timesteps=eval_frequency)
    # Only training is timed; evaluation happens after the clock stops.
    elapsed = time.time() - start

    mean_return = 0.0
    for _ in range(eval_rollouts):
        mean_return += evaluate_episode(model, env)
    mean_return /= eval_rollouts

    print(f"Return ({short_label}): {mean_return}")
    print(f"Time taken ({long_label}): {elapsed:.2f} seconds")
    print(f"Timesteps: {timesteps_before}, Mean Return: {mean_return}")
    return mean_return


# Running sums over all evaluations ("avg") and over the evaluations in the
# last 5% of training ("final"); turned into means after the loop.
avg_default = 0.0
avg_pink = 0.0
avg_OU = 0.0
final_default = 0.0
final_pink = 0.0
final_OU = 0.0
# Count evaluations instead of deriving the divisors from
# total_timesteps / eval_frequency, which is wrong when the two do not divide
# evenly and divides by zero for an empty run.
num_evals = 0
num_final_evals = 0

# Train agents with evaluation
timesteps_so_far = 0
while timesteps_so_far < total_timesteps:
    in_final_window = timesteps_so_far >= 0.95 * total_timesteps

    mean_return_default = _train_and_evaluate(
        model_default, "Default", "Default Model", timesteps_so_far
    )
    mean_return_pink = _train_and_evaluate(
        model_pink, "Pink", "Pink Noise Model", timesteps_so_far
    )
    mean_return_OU = _train_and_evaluate(
        model_OU, "OU", "OU Noise Model", timesteps_so_far
    )

    avg_default += mean_return_default
    avg_pink += mean_return_pink
    avg_OU += mean_return_OU
    num_evals += 1
    if in_final_window:
        final_default += mean_return_default
        final_pink += mean_return_pink
        final_OU += mean_return_OU
        num_final_evals += 1

    timesteps_so_far += eval_frequency

    wandb.log({
        "mean_return_OU": mean_return_OU,
        "mean_return_pink": mean_return_pink,
        "mean_return_default": mean_return_default,
        "timesteps_so_far": timesteps_so_far
    })

# Convert the running sums into means; leave them at 0.0 if nothing was
# evaluated.
if num_evals:
    avg_default /= num_evals
    avg_pink /= num_evals
    avg_OU /= num_evals

if num_final_evals:
    final_default /= num_final_evals
    final_pink /= num_final_evals
    final_OU /= num_final_evals

wandb.log({
    "final_default": final_default,
    "final_pink": final_pink,
    "final_OU": final_OU,
    "avg_default": avg_default,
    "avg_pink": avg_pink,
    "avg_OU": avg_OU
})

print("Mean:")
print(f"White:{avg_default}           Pink:{avg_pink}             OU:{avg_OU}")
print("Final:")
print(f"White:{final_default}           Pink:{final_pink}             OU:{final_OU}")