# Reinforcement Learning

An exercise in spoken language acquisition, based on the paper:

> M. Zhang, T. Tanaka, W. Hou, S. Gao, T. Shinozaki,
"[Sound-Image Grounding Based Focusing Mechanism for Efficient Automatic Spoken Language Acquisition](http://www.interspeech2020.org/uploadfile/pdf/Thu-2-4-4.pdf),"
in *Proc. Interspeech*, 2020.

The details differ slightly from the original paper.

## Table of contents
1. Definition of food
1. Dialogue partner
1. SpoLacq environment
1. Data path
1. Food type
1. Prepare ASR model
1. Make sound dictionary
1. RL environment and model creation
1. Train RL model
1. Save RL model and replay buffer
1. Load RL model and replay buffer
1. Retrain RL model
1. Test the learnt agent

In [None]:
from IPython.display import Audio
from glob import glob
import os
import pickle
import random
import sys
from typing import Callable, List, Tuple, Union

import gym
import librosa
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
from tqdm import tqdm

os.chdir("..")
sys.path.append("utils")

from utils.sb3_api import CustomDQNPolicy, CustomDQN
from utils.wav2vec2_api import ASR

## Definition of food

In [None]:
class Food():
    """
    Definition of food.
    A food has a name, an image, and the mean RGB of that image.
    """
    
    def __init__(self, name: str, image: np.ndarray, RGB: np.ndarray):
        self.name = name
        self.image = image
        self.RGB = RGB

## Dialogue partner

<img src="../fig/env.png" width=600>

In [None]:
class DialogWorld:
    """
    Dialogue partner for the language learning agent,
    which is the outside environment.
    An agent receives two randomly chosen images of food from DialogWorld.
    The DialogWorld recognizes an agent's utterance and returns feedback.
    
    :param MAX_STEP: Number of dialogue turns per episode.
    :param hasNO: Set True if the agent has the option of eating
        neither of the two foods.
    :param FOODS: Tuple of food names. The type of food names must match
        the type of the return value of ASR. For example, Wav2Vec2 returns
        a space-delimited sequence of uppercase letters, e.g., GREEN PEPPER.
    :param datadir: A directory of food images.
    :param asr: If args.use_real_time_asr==True, you can use any ASR of
        Callable[[np.ndarray], str]; otherwise, it is the identity function,
        i.e., lambda x: x.
    """
    
    def __init__(
        self,
        MAX_STEP: int,
        hasNO: bool,
        FOODS: tuple,
        datadir: str,
        asr: Callable[[Union[np.ndarray, str]], str],
    ):
        print("Initializing DialogWorld")
        self.MAX_STEP = MAX_STEP
        self.hasNO = hasNO
        self.FOODS = FOODS
        self.make_food_storage(datadir)
        self.asr = asr
        print("Have prepared ASR")
        self.reset()
    
    def make_food_storage(self, datadir: str) -> None:
        self.food_storage = list() # list of food that the environment has
        for f in self.FOODS:
            if f == "NO": continue
            image_paths = glob(f"{datadir}/{f.replace(' ', '_').lower()}/test*/*.jpg")
            assert len(image_paths) == 30, "mismatch in test image dataset"
            for image_path in image_paths:
                image = Image.open(image_path).convert("RGB") # guarantee 3 channels
                image = image.resize((224, 224))
                image = np.array(image)
                RGB = np.mean(image.astype(float)/255, axis=(0,1))
                self.food_storage.append(Food(f, image, RGB))
    
    def reset(self) -> None:
        self.num = self.MAX_STEP
        self.leftfood = random.choice(self.food_storage)
        self.rightfood = random.choice(self.food_storage)
        self.done = False
    
    def step(self, action: Union[np.ndarray, str]) -> Tuple[Tuple[bool, int], bool]:
        text = self.asr(action)
        if self.hasNO:
            correct_answer = ["NO", self.leftfood.name, self.rightfood.name]
        else:
            correct_answer = [self.leftfood.name, self.rightfood.name]
        if text in correct_answer:
            dlg_success = True
            foodID = self.FOODS.index(text)
        else:
            dlg_success = False
            foodID = -1
        self.num -= 1
        if self.num == 0:
            self.done = True
        self.leftfood = random.choice(self.food_storage)
        self.rightfood = random.choice(self.food_storage)
        return (dlg_success, foodID), self.done
    
    def observe(self) -> dict:
        return dict(num=self.num, leftfood=self.leftfood, rightfood=self.rightfood)
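
The feedback rule in `DialogWorld.step` can be tried without images or an ASR model. The following is a minimal sketch, assuming the identity ASR (the text-based setting) and a shortened food list; `dialogue_feedback` is a hypothetical helper that is not part of the notebook.

```python
from typing import Tuple

FOODS = ("CHERRY", "LEMON", "TOMATO")  # shortened list, for illustration only


def dialogue_feedback(text: str, left: str, right: str, has_no: bool = False) -> Tuple[bool, int]:
    """Hypothetical helper mirroring the feedback rule of DialogWorld.step.

    Returns (dlg_success, foodID); foodID is -1 when the utterance
    does not name one of the two foods on offer.
    """
    correct_answers = [left, right]
    if has_no:
        correct_answers.append("NO")  # "NO" must then also be listed in FOODS
    if text in correct_answers:
        return True, FOODS.index(text)
    return False, -1


print(dialogue_feedback("LEMON", "LEMON", "TOMATO"))   # (True, 1)
print(dialogue_feedback("CHERRY", "LEMON", "TOMATO"))  # (False, -1)
```

Note that a successful dialogue only means the utterance named one of the two foods on offer. Whether it named the preferred food is decided later, in the reward function of the RL environment.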

## SpoLacq environment

In [None]:
class SpoLacq1(gym.Env):
    """
    RL environment described in the paper
    
    M. Zhang, T. Tanaka, W. Hou, S. Gao, T. Shinozaki,
    "Sound-Image Grounding Based Focusing Mechanism for Efficient Automatic Spoken Language Acquisition,"
    in Proc. Interspeech, 2020.
    
    The internal state of the spolacq agent is held inside this environment.
    An agent has a preferred color (RGB) as its internal state.
    An agent wants the food whose mean color is closer to that color.
    An agent receives two randomly chosen images of food from DialogWorld.
    An agent is rewarded when it speaks the name of the food it wants.
    
    :param FOODS: Tuple of food names. The type of food names must match
        the type of the return value of ASR. For example, Wav2Vec2 returns
        a space-delimited sequence of uppercase letters, e.g., GREEN PEPPER.
    :param datadir: A directory of food images.
    :param sounddic: If args.use_real_time_asr==True, it is an object of
        List[np.ndarray]; otherwise, it is an object of List[str].
    :param asr: If args.use_real_time_asr==True, you can use any ASR of
        Callable[[np.ndarray], str]; otherwise, it is the identity function,
        i.e., lambda x: x.
    """
    
    MAX_RGB = 1 # normalized image
    
    def __init__(
        self,
        FOODS: tuple,
        datadir: str,
        sounddic: Union[List[np.ndarray], List[str]],
        asr: Callable[[Union[np.ndarray, str]], str],
    ):
        super().__init__()
        self.dlgworld = DialogWorld(1, False, FOODS, datadir, asr)
        self.action_space = gym.spaces.Discrete(len(sounddic))
        self.observation_space = gym.spaces.Dict(
            dict(
                state=gym.spaces.Box(low=0, high=self.MAX_RGB, shape=(3,)),
                leftfoodRGB=gym.spaces.Box(low=0, high=self.MAX_RGB, shape=(3,)),
                rightfoodRGB=gym.spaces.Box(low=0, high=self.MAX_RGB, shape=(3,)),
                leftimage=gym.spaces.Box(low=0, high=255, shape=(224, 224, 3), dtype=np.uint8),
                rightimage=gym.spaces.Box(low=0, high=255, shape=(224, 224, 3), dtype=np.uint8),
                step=gym.spaces.Discrete(self.dlgworld.MAX_STEP+1),
                leftfoodID=gym.spaces.Discrete(len(FOODS)),
                rightfoodID=gym.spaces.Discrete(len(FOODS)),
            )
        )
        # Sound dictionary: maps a discrete action ID to an utterance
        self.sounddic = sounddic # waveform or text, depending on the ASR
        self.FOODS = FOODS
        self.succeeded_log = list()
        self.reset()
    
    def reset(self) -> dict:
        # Initialize the dialogue world
        self.dlgworld.reset()
        # Initialize the internal state of spolacq agent
        self.preferredR = random.random()
        self.preferredG = random.random()
        self.preferredB = random.random()
        return self.observe()
    
    def step(self, action: int) -> Tuple[dict, int, bool, dict]:
        """:param action: A number in self.action_space"""
        old_state = self.observe()
        utterance = self.sounddic[action]
        feedback, dlg_done = self.dlgworld.step(utterance)
        self.update_internal_state(feedback)
        new_state = self.observe()
        reward = self.reward(old_state, new_state, feedback)
        if reward > 0: self.succeeded_log.append(action)
        return new_state, reward, dlg_done, dict()
    
    def observe(self) -> dict:
        insideobs = np.array([self.preferredR, self.preferredG, self.preferredB])
        outsideobs = self.dlgworld.observe()
        return dict(
            state=insideobs,
            leftfoodRGB=outsideobs["leftfood"].RGB,
            rightfoodRGB=outsideobs["rightfood"].RGB,
            leftimage=outsideobs["leftfood"].image,
            rightimage=outsideobs["rightfood"].image,
            step=outsideobs["num"],
            leftfoodID=self.FOODS.index(outsideobs["leftfood"].name),
            rightfoodID=self.FOODS.index(outsideobs["rightfood"].name),
        )
    
    def update_internal_state(self, feedback: Tuple[bool, int]) -> None:
        # Resample the preferred color after every dialogue turn
        self.preferredR = random.random()
        self.preferredG = random.random()
        self.preferredB = random.random()
    
    def reward(self, old_state: dict, new_state: dict, feedback: Tuple[bool, int]) -> int:
        if not feedback[0]: # failed dialogue
            return 0
        leftfood_distance = np.linalg.norm(old_state["state"] - old_state["leftfoodRGB"])
        rightfood_distance = np.linalg.norm(old_state["state"] - old_state["rightfoodRGB"])
        if leftfood_distance < rightfood_distance and feedback[1] == old_state["leftfoodID"]:
            return 1 # The agent wants the left food and named it
        elif leftfood_distance >= rightfood_distance and feedback[1] == old_state["rightfoodID"]:
            return 1 # The agent wants the right food and named it
        else:
            return 0
    
    def render(self, mode: str = "console", close=False) -> None:
        state = self.observe()
        leftfood_distance = np.linalg.norm(state["state"] - state["leftfoodRGB"])
        rightfood_distance = np.linalg.norm(state["state"] - state["rightfoodRGB"])
        if leftfood_distance < rightfood_distance:
            print(f"The left {self.FOODS[state['leftfoodID']]} is preferred.")
        else:
            print(f"The right {self.FOODS[state['rightfoodID']]} is preferred.")
        if mode == "human":
            plt.subplot(1, 2, 1)
            plt.imshow(state["leftimage"])
            plt.subplot(1, 2, 2)
            plt.imshow(state["rightimage"])
    
    def close(self) -> None:
        pass
    
    def seed(self, seed=None) -> None:
        random.seed(seed)
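
The preference rule used in `SpoLacq1.reward` and `SpoLacq1.render` compares Euclidean distances in RGB space. Below is a minimal, dependency-free sketch of that comparison; `preferred_side` is a hypothetical helper and the colors are illustrative values, not measurements from the dataset.

```python
import math
from typing import Sequence


def preferred_side(
    preferred: Sequence[float],
    left_rgb: Sequence[float],
    right_rgb: Sequence[float],
) -> str:
    """Hypothetical helper: which food is closer to the preferred color.

    Ties go to the right food, matching the `>=` branch in SpoLacq1.reward.
    """
    left_distance = math.dist(preferred, left_rgb)
    right_distance = math.dist(preferred, right_rgb)
    return "left" if left_distance < right_distance else "right"


# Illustrative normalized RGB values only.
reddish = (0.8, 0.1, 0.1)
yellowish = (0.9, 0.8, 0.2)
print(preferred_side((1.0, 0.0, 0.0), reddish, yellowish))  # left
print(preferred_side((0.5, 0.5, 0.5), reddish, reddish))    # right (tie)
```

The reward is 1 only when the dialogue succeeded and the recognized food ID belongs to the preferred side; otherwise it is 0.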

## Data path

In [None]:
class Args:
    def __init__(self):
        self.datadir = "data/dataset"
        self.workdir = "exp1/unsup_rl/test"
        self.load_state_dict_path = self.workdir + "/unsup_backend.pth.tar"
        self.segment_pkl = self.workdir + "/seg_audios.txt"
        self.image_nnfeat = self.workdir + "/data/train_img_fea_sim_8type_clean_limited_data.npy"
        self.segment_nnfeat = self.workdir + "/data/new720db20k2_fea_sim_8type_clean_limited_data.npy"


args = Args()

## Food type

In [None]:
FOODS = (
    "CHERRY",
    "GREEN PEPPER",
    "LEMON",
    "ORANGE",
    "POTATO",
    "STRAWBERRY",
    "SWEET POTATO",
    "TOMATO",
)
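
Each name in `FOODS` is also used to locate the test images: `DialogWorld.make_food_storage` lowercases the name and replaces spaces with underscores. A small sketch of that mapping follows; `food_glob_pattern` is a hypothetical helper introduced only for illustration.

```python
def food_glob_pattern(datadir: str, food_name: str) -> str:
    """Hypothetical helper reproducing the glob pattern of make_food_storage."""
    return f"{datadir}/{food_name.replace(' ', '_').lower()}/test*/*.jpg"


print(food_glob_pattern("data/dataset", "GREEN PEPPER"))
# data/dataset/green_pepper/test*/*.jpg
```

The image directories therefore need lowercase, underscore-separated names, while the entries of `FOODS` must match the ASR output format.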

## Prepare ASR model

The `ASR` class is defined in [../utils/wav2vec2_api.py](../utils/wav2vec2_api.py).

In [None]:
asr = ASR("facebook/wav2vec2-large-robust-ft-libri-960h")

## Make sound dictionary

<img src="../fig/corr.png" width=600>

In [None]:
# Focusing mechanism
image_features = np.load(args.image_nnfeat).squeeze()
segment_features = np.load(args.segment_nnfeat).squeeze()
kmeans = KMeans(n_clusters=120, random_state=2).fit(image_features)
similarity = -cdist(kmeans.cluster_centers_, segment_features)
focused_segment_ids = similarity.argsort(axis=1)[:, -100:].flatten()

with open(args.workdir + "/sounddic.pkl", "rb") as f:
    sounddic = pickle.load(f)
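
The selection step of the focusing mechanism (negative distance as similarity, then the top entries of an ascending sort) can be illustrated on toy data. This is a dependency-free sketch under assumed 2-D features; `focus_segments` is a hypothetical helper, and the real features come from the `.npy` files loaded above.

```python
import math
from typing import List, Sequence


def focus_segments(
    centers: Sequence[Sequence[float]],
    segments: Sequence[Sequence[float]],
    k: int,
) -> List[int]:
    """For each center, keep the indices of the k nearest segment features.

    Within each center, indices run from the farthest to the nearest of the
    k kept segments, like `similarity.argsort(axis=1)[:, -k:]`.
    """
    focused = []
    for center in centers:
        # Ascending similarity, where similarity = -distance.
        by_similarity = sorted(
            range(len(segments)),
            key=lambda i: -math.dist(center, segments[i]),
        )
        focused.extend(by_similarity[-k:])
    return focused


# Toy features, for illustration only.
centers = [(0.0, 0.0), (10.0, 10.0)]
segments = [(0.1, 0.0), (9.0, 9.5), (0.5, 0.5), (10.0, 10.2), (5.0, 5.0)]
print(focus_segments(centers, segments, k=2))  # [2, 0, 1, 3]
```

With 120 cluster centers and 100 segments per center, as in the cell above, the flattened result has 12,000 entries, possibly with repeated segment IDs.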

## RL environment and model creation

The custom policy network is defined in [../utils/sb3_api.py](../utils/sb3_api.py).

<img src="../fig/dqn.png" width=600>

In [None]:
env = SpoLacq1(FOODS, args.datadir, sounddic, asr)

# Create the RL model
policy_kwargs = dict(
    features_extractor_kwargs=dict(
        cluster_centers=kmeans.cluster_centers_,
        load_state_dict_path=args.load_state_dict_path,
        num_per_group=100,
        purify_rate=0.97,
    )
)
replay_buffer_kwargs = dict(handle_timeout_termination=False)

model = CustomDQN(
    CustomDQNPolicy,
    env,
    policy_kwargs=policy_kwargs,
    replay_buffer_kwargs=replay_buffer_kwargs,
    tensorboard_log="./spolacq_tmplog/",
    buffer_size=1000,
    learning_starts=50,
    verbose=1,
)

## Train RL model

In [None]:
# total_timesteps is kept at 10 here to avoid CUDA out-of-memory errors
model.learn(total_timesteps=10)

## Save RL model and replay buffer

In [None]:
# Uncomment to save; the loading cell below expects these files to exist.
# model.save(args.workdir+'/dqn')
# model.save_replay_buffer(args.workdir+'/replay_buffer')

## Load RL model and replay buffer

In [None]:
del model
del env
env = SpoLacq1(FOODS, args.datadir, sounddic, asr)
model = CustomDQN.load(args.workdir+"/dqn", env=env)
model.load_replay_buffer(args.workdir+"/replay_buffer")

## Retrain RL model
With `reset_num_timesteps=False`, the timestep counter is not reset, so training continues the previous TensorBoard log instead of starting a new one.

In [None]:
# model.learn(total_timesteps=50000, reset_num_timesteps=False)

## Test the learnt agent

In [None]:
state = env.reset()
env.render(mode="human") # render the state

# The agent receives the environment state and returns its chosen action
action, _ = model.predict(state, deterministic=True)

# The environment receives the action, advances one time step,
# and returns the new state, the reward, and so on.
state, reward, done, info = env.step(action)
print(f"utterance: {asr(sounddic[action])}, reward: {reward}\n")
Audio(sounddic[action], rate=16000)
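
A single step says little about how well the agent has learnt, so averaging the reward over many episodes is a natural follow-up. The loop below is a minimal sketch that uses a hypothetical `ToyEnv` in place of `SpoLacq1` and a plain function in place of the trained model; it assumes the same `reset`/`step` interface as the environment in this notebook.

```python
import random
from typing import Callable, Tuple


class ToyEnv:
    """Hypothetical stand-in for SpoLacq1 with one-step episodes.

    The observation is 0 or 1, and the reward is 1 when the action equals it.
    """

    def __init__(self, seed: int = 0):
        self.rng = random.Random(seed)
        self.target = 0

    def reset(self) -> int:
        self.target = self.rng.randint(0, 1)
        return self.target

    def step(self, action: int) -> Tuple[int, int, bool, dict]:
        reward = int(action == self.target)
        # Every episode ends after a single dialogue turn.
        return self.reset(), reward, True, dict()


def mean_reward(env, predict: Callable[[int], int], n_episodes: int) -> float:
    """Average episode return of `predict` over n_episodes."""
    total = 0
    for _ in range(n_episodes):
        state, done = env.reset(), False
        while not done:
            state, reward, done, _ = env.step(predict(state))
            total += reward
    return total / n_episodes


# A policy that always answers correctly reaches the maximum of 1.0.
print(mean_reward(ToyEnv(), predict=lambda state: state, n_episodes=100))  # 1.0
```

With the objects of this notebook, one would presumably pass `env` and `lambda s: model.predict(s, deterministic=True)[0]`; that combination has not been run here.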