In this notebook we test whether our audio wrappers, RetroSound and FFTWrapper, work correctly, that is, whether the agent actually receives the audio and can learn from it.
To do so, we build a simple toy environment that the agent can only solve by processing audio in a sensible way.

We will code a 1D game with two potential goals, one of which is picked at random as the correct goal on each episode. The agent receives a sine wave whose frequency rises as it gets closer to the correct goal. Without audio information it can only learn to pick a goal at random or to always pick the same one. With audio information it should learn to go to the correct goal every time.
Here is what the environment will look like:

\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_ <br>
| 0  1  2  3  4  5  6  7  8 |<br>
¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
 
We start the game at a random position in [2, 6]. We can move left or right, and every step gives a reward of -0.025. The correct goal is either 0 or 8; reaching it gives a reward of 1 and ends the game. Reaching the wrong goal gives a reward of -1 and also ends the game. The closer we are to the correct goal, the higher the frequency of the sine wave.
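
The rules above already tell us what to look for in the training curves. The sketch below is a back-of-the-envelope calculation, not part of the environment: it assumes start positions 2 to 6 as described above, and that the last step's -0.025 is replaced by the goal reward. The names `episode_return` and `mean_return` are hypothetical helpers introduced for illustration.

```python
STEP_PENALTY = -0.025  # reward for every non-terminal step
GOALS = (0, 8)         # the two ends of the corridor
STARTS = range(2, 7)   # start positions 2..6 (assumption taken from the text)


def episode_return(start: int, target: int, win_state: int) -> float:
    """Return of walking straight from `start` to the end state `target`.

    Every step except the last costs STEP_PENALTY; the last step yields
    +1 if `target` is the correct goal and -1 otherwise.
    """
    steps = abs(start - target)
    terminal = 1.0 if target == win_state else -1.0
    return STEP_PENALTY * (steps - 1) + terminal


def mean_return(policy) -> float:
    """Average return over all start positions and both possible goals."""
    returns = [
        episode_return(start, policy(start, win_state), win_state)
        for start in STARTS
        for win_state in GOALS
    ]
    return sum(returns) / len(returns)


# Blind policy: ignores the goal and always walks left.
blind = mean_return(lambda start, win_state: 0)
# Informed policy: knows which end is correct and walks straight to it.
informed = mean_return(lambda start, win_state: win_state)

print(f"blind: {blind:.3f}, informed: {informed:.3f}")
```

Under these assumptions the always-left policy averages about -0.075 and the informed one about 0.925. Any policy that ignores the goal hits +1 and -1 equally often, so its mean return cannot rise above zero; a mean episode reward close to 1 is therefore a sign that the agent is using the audio.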

# Preparations

Works with Colab; adjust for conda, Windows, etc. if needed.

In [None]:
# Install libraries
!pip install -q gym-retro

# Get my version of stable-baselines3 with audio support
!git clone -b gym-to-retro https://github.com/rienath/stable-baselines3.git
!mv stable-baselines3/stable_baselines3 stable_baselines3
!rm -rf stable-baselines3/

Imports

In [None]:
from typing import Dict, Union
import gym
import numpy as np
import stable_baselines3
from stable_baselines3.ppo import PPO
from stable_baselines3.common.atari_wrappers import RetroSound, FFTWrapper
from stable_baselines3.common.type_aliases import GymStepReturn

Define the classes that we are going to use

In [None]:
class emulator():
    """
    Models a retro emulator that can return audio, so that the wrappers can call env.em.get_audio().
    """
    def __init__(self, audio):
        self.audio = audio

    def get_audio(self):
        return self.audio
    
    def update_audio(self, audio):
        self.audio = audio 


class AudioMultiObsEnv(gym.Env):
    """
    GridWorld-based MultiObs environment, 1x9.
        ___________________________
       | 0  1  2  3  4  5  6  7  8 |
       ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
    start is 4 or random
    goal is 0 or 8, randomly determined in every round
    if we reach incorrect goal, we get punished (reward -1)
    if we reach correct goal, we get rewarded (reward 1)
    actions are = [left, right]
    each state is represented by a random image
    each state has stereo audio; the closer the state is to the goal, the higher the frequency
    each step has a punishment (reward -0.025)
    :param num_col: Number of columns in the grid
    :param random_start: If true, agent starts in random position apart from the goal/false-goal positions and states directly next to them
    """

    def __init__(
        self,
        num_col: int = 9,
        random_start: bool = True
    ):
        super(AudioMultiObsEnv, self).__init__()

        self.img_size = [84, 84, 1]

        self.random_start = random_start
        self.action_space = gym.spaces.Discrete(2)

        self.observation_space = gym.spaces.Dict(
            spaces={
                "obs": gym.spaces.Box(0, 255, self.img_size, dtype=np.uint8),
            }
        )

        self.count = 0
        self.max_count = 20 # Maximum number of steps per episode
        self.log = ""
        self.state = 4
        self.action2str = ["left", "right"]
        self.init_possible_transitions()
        self.num_col = num_col
        self.state_mapping = []
        self.sounds = []
        self.init_state_mapping()
        self.min_state = 0
        self.max_state = num_col - 1
        self.win_state = np.random.choice([self.min_state, self.max_state])
        self.update_state_mapping()
        self.em = emulator(self.get_state_mapping()["sound"]) # Makes env.em.get_audio possible for wrappers


    def init_state_mapping(self) -> None:
        """
        Initializes the state_mapping array which holds the observation values for each state
        """

        # Each state is represented by a random image
        # Upper bound is exclusive, so 256 covers the full uint8 range declared in observation_space
        col_imgs = np.random.randint(0, 256, (self.num_col, 84, 84), dtype=np.uint8)

        # Each state is represented by sin wave of a certain frequency. The closer the state is to the goal
        # the higher the frequency.
        for i in range(self.num_col):
            # Get x values of the sin wave
            time = np.arange(0, 524, 1)
            frequency = i*100
            # Amplitude of the sine wave is the sine of frequency * time
            amplitude = np.sin(frequency*time)
            # Get the amplitude to format [[l, r], [l, r]...] and so on as this is the representation
            # of stereo audio in retro.
            stereo_waves = np.transpose([amplitude, amplitude])
            self.sounds.append(stereo_waves)
        
        # Arrange sounds as if goal is the last state by default
        for i in range(self.num_col):
            self.state_mapping.append({"obs":   col_imgs[i].reshape(self.img_size),
                                       "sound": self.sounds[i]})
    
    def update_state_mapping(self) -> None:
        """
        Assigns sounds to states so that the frequency rises towards the current goal.
        """
        # self.sounds is ordered by rising frequency, so it already fits a goal in the last state.
        # If the goal is the first state instead, reverse the sounds.
        if self.win_state == self.max_state:
            for i in range(self.num_col):
                self.state_mapping[i]["sound"] = self.sounds[i]
        else:
            sounds_reversed = self.sounds[::-1]
            for i in range(self.num_col):
                self.state_mapping[i]["sound"] = sounds_reversed[i]

    def get_state_mapping(self) -> Dict[str, np.ndarray]:
        """
        Uses the state to get the observation mapping.
        :return: observation dict
        """
        return self.state_mapping[self.state]

    def init_possible_transitions(self) -> None:
        """
        Initializes the transitions of the environment.
        """
        self.left_possible = [1, 2, 3, 4, 5, 6, 7, 8]
        self.right_possible = [0, 1, 2, 3, 4, 5, 6, 7]
        self.down_possible = []
        self.up_possible = []

    def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn:
        """
        Run one timestep of the environment's dynamics. When end of
        episode is reached, you are responsible for calling `reset()`
        to reset this environment's state.
        Accepts an action and returns a tuple (observation, reward, done, info).
        :param action:
        :return: tuple (observation, reward, done, info).
        """
        action = int(action)

        self.count += 1

        prev_state = self.state

        reward = -0.025
        # Define state transition
        if self.state in self.left_possible and action == 0:  # left
            self.state -= 1
        elif self.state in self.right_possible and action == 1:  # right
            self.state += 1

        got_to_end = self.state == self.max_state or self.state == self.min_state
        if got_to_end:
            reward = 1 if self.state == self.win_state else -1
        done = self.count > self.max_count or got_to_end

        self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}"

        # Update audio in emulator
        self.em.update_audio(self.get_state_mapping()["sound"])

        return {"obs": self.get_state_mapping()["obs"]}, reward, done, {"got_to_end": got_to_end}

    def render(self, mode: str = "human") -> None:
        """
        Prints the log of the environment.
        :param mode:
        """
        print(self.log)

    def reset(self) -> Dict[str, np.ndarray]:
        """
        Resets the environment state and step count and returns reset observation.
        """
        self.count = 0
        self.win_state = np.random.choice([self.min_state, self.max_state])
        # Update frequencies of states as the goal might have changed.
        self.update_state_mapping()
        if not self.random_start:
            self.state = 4
        else:
            # Upper bound is exclusive: this samples from 2..max_state - 2 (2..6 for nine columns)
            self.state = np.random.randint(2, self.max_state - 1)
        # Update audio in emulator
        self.em.update_audio(self.get_state_mapping()["sound"])
        return {"obs": self.get_state_mapping()["obs"]}
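
The sound layout is meant to keep one invariant from the docstring: whichever end is the goal, the frequency rises as the agent approaches it. A minimal sketch of that invariant, independent of the class, is shown below. `tone_index` is a hypothetical helper (not part of the environment) that returns the position of a state's tone in a list ordered from lowest to highest frequency.

```python
def tone_index(state: int, win_state: int, num_col: int = 9) -> int:
    """Index of the tone heard in `state`, where 0 is the lowest frequency.

    The goal state itself gets the highest tone (num_col - 1); every step
    away from the goal lowers the index by one.
    """
    max_state = num_col - 1
    distance_to_goal = abs(win_state - state)
    return max_state - distance_to_goal


# Goal on the right: tones rise from left to right.
print([tone_index(s, win_state=8) for s in range(9)])
# Goal on the left: the same tones, reversed.
print([tone_index(s, win_state=0) for s in range(9)])
```

Comparing these indices with the sounds actually stored in `state_mapping` after a `reset()` is a quick way to check that the reversal goes in the intended direction.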

# Testing

In [None]:
# Load tensorboard
%load_ext tensorboard
%tensorboard --logdir ./AUDIO_TEST/

## Just image

In [None]:
env = AudioMultiObsEnv()
model = PPO("MultiInputPolicy", env, verbose=0, tensorboard_log="./AUDIO_TEST/")
model.learn(total_timesteps=200000)
env.close()

## Image + raw audio

In [None]:
env = AudioMultiObsEnv()
env = RetroSound(env)
model = PPO("MultiInputPolicy", env, verbose=0, tensorboard_log="./AUDIO_TEST/")
model.learn(total_timesteps=200000)
env.close()

## Image + FFT audio

In [None]:
env = AudioMultiObsEnv()
env = RetroSound(env)
env = FFTWrapper(env)
model = PPO("MultiInputPolicy", env, verbose=0, tensorboard_log="./AUDIO_TEST/")
model.learn(total_timesteps=200000)
env.close()
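
Why a frequency-domain view might help: in the raw waveform the pitch is spread over every sample, while after a Fourier transform it shows up as a single dominant bin. The sketch below illustrates this with a naive discrete Fourier transform in plain Python. It is a generic illustration and makes no claim about how FFTWrapper is implemented; the tones, the window length `N` and the helpers `tone`, `dft_magnitudes` and `peak_bin` are assumptions made for the example. In practice `np.fft.rfft` would be the usual choice.

```python
import cmath
import math

N = 64  # window length in samples (hypothetical, chosen to keep the DFT cheap)


def tone(cycles: int) -> list:
    """Sine wave that completes `cycles` full periods within N samples."""
    return [math.sin(2 * math.pi * cycles * t / N) for t in range(N)]


def dft_magnitudes(samples: list) -> list:
    """Magnitudes of the non-negative frequency bins of a naive O(n^2) DFT."""
    n = len(samples)
    return [
        abs(sum(x * cmath.exp(-2j * math.pi * k * t / n)
                for t, x in enumerate(samples)))
        for k in range(n // 2 + 1)
    ]


def peak_bin(samples: list) -> int:
    """Index of the strongest frequency bin."""
    mags = dft_magnitudes(samples)
    return max(range(len(mags)), key=mags.__getitem__)


# Two tones of different pitch map to two different peak bins.
print(peak_bin(tone(3)), peak_bin(tone(10)))
```

Each tone peaks at the bin equal to its number of cycles, so "how high is the pitch" becomes "where is the peak", which is plausibly an easier feature for a policy network to pick up than the raw samples.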