In [3]:
import random
import torch
from torch.distributions import MultivariateNormal
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import gensim
import gensim.downloader as glove_api
import os
import io
import time

from matplotlib import pyplot as pl
import pickle

from ZorkGym.text_utils.text_parser import BagOfWords, Word2Vec, TextParser, tokenizer, BasicParser
from agents.OMP_DDPG import OMPDDPG

In [4]:
from __future__ import division, print_function
import sys
from scipy.linalg import norm
from math import sqrt
from sklearn.base import BaseEstimator
from sklearn.datasets.base import Bunch
from sklearn.metrics import roc_auc_score
from hashlib import sha1

In [5]:
"""
Module implementing the FISTA algorithm
"""
__author__ = 'Jean KOSSAIFI'


def mixed_norm(coefs, p, q=None, n_samples=None, n_kernels=None):
    """ Computes the (p, q) mixed norm of the vector coefs

    The vector is viewed as n_kernels consecutive groups of n_samples
    entries: the p-norm of every group is taken first, then the q-norm of
    those group norms is returned.

    Parameters
    ----------
    coefs : ndarray
        a vector indexed by (l, m)
        with l in range(0, n_kernels)
            and m in range(0, n_samples)

    p : int or np.inf
        order of the inner (per kernel) norm

    q : int or np.inf, optional
        order of the outer norm
        when it is None or equal to p, the plain p-norm of the whole
        vector is returned and the two sizes below are not used

    n_samples : int, optional
        number of elements in each kernel
        default is None

    n_kernels : int, optional
        number of kernels
        default is None

    Returns
    -------
    float
    """
    if q is not None and p != q:
        groups = coefs.reshape(n_kernels, n_samples)
        group_norms = [norm(group, p) for group in groups]
        return norm(group_norms, q)
    # Same order inside and outside the groups: nothing to regroup.
    return norm(coefs, p)


def dual_mixed_norm(coefs, n_samples, n_kernels, norm_):
    """ Evaluates on coefs the norm that is dual to the penalty norm_

    Parameters
    ----------
    coefs : ndarray
        a vector indexed by (l, m)
        with l in range(0, n_kernels)
            and m in range(0, n_samples)

    n_samples : int
        number of elements in each kernel

    n_kernels : int
        number of kernels

    norm_ : {'l11', 'l12', 'l21', 'l22'}
        name of the penalty whose dual norm is wanted
        any other value is handled like 'l22'

    Returns
    -------
    float
    """
    if norm_ == 'l11':
        return norm(coefs, np.inf)
    if norm_ == 'l12':
        return mixed_norm(coefs, np.inf, 2, n_samples, n_kernels)
    if norm_ == 'l21':
        return mixed_norm(coefs, 2, np.inf, n_samples, n_kernels)
    # 'l22', and every unrecognised name, use the plain euclidean norm.
    return norm(coefs, 2)


def by_kernel_norm(coefs, p, q, n_samples, n_kernels):
    """ Computes the (p, q) norm of coefs separately for each kernel

    Parameters
    ----------
    coefs : ndarray
        a vector indexed by (l, m)
        with l in range(0, n_kernels)
            and m in range(0, n_samples)

    p : int or np.inf

    q : int or np.inf

    n_samples : int
        number of elements in each kernel

    n_kernels : int
        number of kernels

    Returns
    -------
    A list with one norm per kernel, in kernel order
    """
    norms = []
    for kernel_coefs in coefs.reshape(n_kernels, n_samples):
        # Each sub vector is handed over as a single kernel of n_samples.
        norms.append(mixed_norm(kernel_coefs, p, q, n_samples, 1))
    return norms


def prox_l11(u, lambda_):
    """ Proximity operator for l(1, 1, 2) norm (soft thresholding)

    :math:`\\hat{\\alpha}_{l,m} = sign(u_{l,m})\\left||u_{l,m}| - \\lambda \\right|_+`

    Parameters
    ----------
    u : ndarray
        The vector (of the n-dimensional space) on which we want
        to compute the proximal operator

    lambda_ : float
        regularisation parameter

    Returns
    -------
    ndarray : a new vector, the result of applying the proximity
              operator to u

    """
    # Every magnitude is shrunk by lambda_ and clipped at zero; the sign of
    # the original entry is then restored.
    shrunk_magnitude = np.maximum(np.abs(u) - lambda_, 0.)
    return np.sign(u) * shrunk_magnitude

def prox_l22(u, lambda_):
    """ Proximity operator for l(2, 2, 2) norm (uniform shrinkage)

    Parameters
    ----------

    u : ndarray
        The vector (of the n-dimensional space) on which we want to compute the proximal operator

    lambda_ : float
        regularisation parameter

    Returns
    -------

    ndarray : a new vector, the result of applying the proximity operator to u

    Notes
    -----

    :math:`\\hat{\\alpha}_{l,m} = \\frac{1}{1 + \\lambda} \\, u_{l,m}`

    """
    shrink_factor = 1./(1.+lambda_)
    return shrink_factor*u

def prox_l21_1(u, l, n_samples, n_kernels):
    """ Proximity operator l(2, 1, 1) norm

    Parameters
    ----------
    u : ndarray
        The vector (of the n-dimensional space) on which we want to compute the proximal operator

    l : float
        regularisation parameter (lambda)

    n_samples : int
        number of elements in each kernel

    n_kernels : int
        number of kernels

    Returns
    -------
    ndarray : a new vector, the result of applying the proximity operator to u
        (u itself is left untouched)


    Notes
    -----

    :math:`\\hat{\\alpha}_{l,m} = u_{l,m} \\left| 1 - \\frac{\\lambda}{\\|u_{l \\bullet}\\|_{2}} \\right|_+`

    where l is in range(0, n_samples) and m is in range(0, n_kernels)
    so :math:`u_{l\\bullet}` = [u(l, m) for m in n_kernels]

    The groups are therefore taken across kernels: group i gathers entry i
    of every kernel. A group whose norm is zero is mapped to zero instead
    of dividing by that norm.

    """
    blocks = u.reshape(n_kernels, n_samples)
    # One euclidean norm per sample index, taken over all the kernels.
    group_norms = np.linalg.norm(blocks, axis=0)
    non_zero = group_norms > 0
    # Placeholder 1. keeps the division defined for empty groups; their
    # scale is forced to 0 right after.
    safe_norms = np.where(non_zero, group_norms, 1.)
    scales = np.where(non_zero, np.maximum(1. - l/safe_norms, 0.), 0.)
    return (blocks*scales).reshape(-1)


def prox_l21(u, l, n_samples, n_kernels):
    """ proximity operator l(2, 1, 2) norm

    Parameters
    ----------
    u : ndarray
        The vector (of the n-dimensional space) on which we want to compute the proximal operator
        It is modified in place.

    l : float
        regularisation parameter (lambda)

    n_samples : int
        number of elements in each kernel

    n_kernels : int
        number of kernels


    Returns
    -------
    ndarray : u itself, after the proximity operator has been applied to it

    Notes
    -----

    :math:`\\hat{\\alpha}_{l,m} = u_{l,m} \\left| 1 - \\frac{ \\lambda}{ \\|u_{l \\bullet }\\|_{2}} \\right|_+`

    where l is in range(0, n_kernels) and m is in range(0, n_samples)
    so :math:`u_{l \\bullet }` = [u(l, m) for m in n_samples]

    """
    for kernel_coefs in u.reshape(n_kernels, n_samples):
        n = norm(kernel_coefs, 2)
        # np.inf (lower case) is the spelling every NumPy release provides.
        if n == 0 or n == np.inf:
            kernel_coefs[:] = 0
        else:
            # The slice assignment writes through the view into u; a bare
            # `kernel_coefs *= ...` on a rebound name would not be enough
            # if the row were replaced by a new array.
            kernel_coefs[:] *= max(1. - l/n, 0.)
    return u


def prox_l12(u, l, n_samples, n_kernels):
    """ proximity operator for l(1, 2, 2) norm

    Parameters
    ----------
    u : ndarray
        The vector (of the n-dimensional space) on which we want to compute the proximal operator
        It is modified in place.

    l : float
        regularisation parameter (lambda)

    n_samples : int
        number of elements in each kernel

    n_kernels : int
        number of kernels

    Returns
    -------
    ndarray : u itself, after the proximity operator has been applied to it


    Notes
    -----

    :math:`\\hat{\\alpha}_{l,m} = sign(u_{l,m})\\left||u_{l,m}| - \\frac{\\lambda \\sum\\limits_{m_l=1}^{M_l} u2_{l,m_l}}{(1+\\lambda M_l) \\|u_{l \\bullet }\\|_{2}} \\right|_+`

    where  :math:`u2_{l,m_l}`  denotes the :math:`|u_{l,m_l}|`
        ordered  by descending  order for fixed  :math:`l`,  and the
            quantity :math:`M_l` is the number computed in compute_M

    """
    for kernel_coefs in u.reshape(n_kernels, n_samples):
        n = norm(kernel_coefs, 2)
        # np.inf (lower case) is the spelling every NumPy release provides.
        if n == 0 or n == np.inf:
            # Degenerate kernel: zeroed without computing M_l at all.
            kernel_coefs[:] = 0
            continue
        Ml, sum_Ml = compute_M(kernel_coefs, l, n_samples)
        threshold = (l*sum_Ml)/((1.+l*Ml)*n)
        # kernel_coefs[:] so that u is really modified
        kernel_coefs[:] = np.sign(kernel_coefs)*np.maximum(
            np.abs(kernel_coefs)-threshold, 0.)
    return u

def compute_M(u, lambda_, n_samples):
    """ Computes the number M_l (and the matching partial sum) used by prox_l12

    Parameters
    ----------
    u : ndarray
        1D array holding the n_samples coefficients of a single kernel

    lambda_ : float
        regularisation parameter

    n_samples : int
        number of elements in each kernel
        ie number of elements of u

    Returns
    -------
    (Ml, sum_Ml) : Ml as defined below (counted from 1) and the sum of the
        Ml largest absolute values of u

    Notes
    -----

    Let u2 be the absolute values of u sorted by descending order.

    :math:`M_l` is the number such that

    :math:`u2_{l,M_l+1} \\leq  \\lambda \\sum_{m_l=1}^{M_l+1} \\left( u2_{l,m_l} - u2_{l,M_l+1}\\right)`

    and

    :math:`u2_{l,M_l} > \\lambda\\sum_{m_l=1}^{M_l} \\left( u2_{l,m_l} - u2_{l,M_l}\\right)`

    Both conditions are evaluated for every candidate position at once:
    S1 <= 0 encodes the first one and S2 > 0 the second one. Positions
    are 0-based in the arrays, hence the final + 1.

    np.argmax returns the first position where both hold; when no position
    qualifies it returns 0, so the function then answers Ml = 1.

    Example

    u = [0 1 2 3] and lambda_ = 1 give u2 = [3 2 1 0],
    S1 = [1 -2 -6] and S2 = [3 1 -2]: only position 1 satisfies both
    conditions, so Ml = 1 + 1 = 2 and the returned sum is 3 + 2 = 5.

    """
    magnitudes = np.sort(np.abs(u))[::-1]
    leading, following = magnitudes[:-1], magnitudes[1:]
    # partial_sums[k] = u2[0] + ... + u2[k], for k in 0..n_samples-2
    partial_sums = np.cumsum(magnitudes)[:-1]
    ranks = np.arange(n_samples-1)+1
    S1 = following - lambda_*(partial_sums - ranks*following)
    S2 = leading - lambda_*(partial_sums - ranks*leading)
    Ml = np.argmax((S1<=0.) & (S2>0.)) + 1

    return Ml, np.sum(magnitudes[:Ml]) # magnitudes[:Ml] = u2[0, 1, ..., Ml-1]


def hinge_step(y, K, Z):
    """
    Returns the direction used for the gradient step of the squared hinge loss

    parameters
    ----------
    y : np-array
        the labels vector
        (not read here: the function only uses K and Z)

    K : 2D np-array
        the concatenation of all the kernels, of shape
        n_samples, n_kernels*n_samples

    Z : a linear combination of the last two coefficient vectors

    returns
    -------
    res : np-array of shape n_samples*n_kernels
          K^T max(1 - K Z, 0)
    """
    # Only the rows whose margin 1 - K.Z is positive contribute.
    positive_margin = np.maximum(1 - np.dot(K, Z), 0)
    return np.dot(K.transpose(), positive_margin)

def least_square_step(y, K, Z):
    """
    Returns the direction used for the gradient step of the least square loss

    parameters
    ----------
    y : np-array
        the labels vector

    K : 2D np-array
        the concatenation of all the kernels, of shape
        n_samples, n_kernels*n_samples

    Z : a linear combination of the last two coefficient vectors

    returns
    -------
    res : np-array of shape n_samples*n_kernels
          K^T (y - K Z)
    """
    residual = y - np.dot(K, Z)
    return np.dot(K.transpose(), residual)


def _load_Lipschitz_constant(K):
    """ Loads the Lipschitz constant and computes it if not already saved

    Parameters
    ----------
    K : 2D-ndarray
        The matrix of which we want to compute the Lipschitz constant

    Returns
    -------
    float (a 0-d ndarray when the value comes from the cache file)

    Notes
    -----
    The value returned is 1/norm(np.dot(K, K.T), 2), ie a step size
    that is < 2/norm(np.dot(K, K.T), 2)

    The constant is stored in a npy hidden file, in the current directory.
    The filename is the sha1 hash of the bytes of the ndarray (shape and
    dtype are not part of the hash).

    Only a missing or unreadable cache file triggers the computation; any
    other error is propagated to the caller.

    """
    # sha1 needs a C-contiguous buffer. ascontiguousarray returns K itself
    # when it already is contiguous, so existing cache files keep their name.
    cache_file = './.%s.npy' % sha1(np.ascontiguousarray(K)).hexdigest()
    try:
        mu = np.load(cache_file)
    except (IOError, OSError, ValueError, EOFError):
        mu = 1/norm(np.dot(K, K.transpose()), 2)
        np.save(cache_file, mu)
    return mu
    

class Fista(BaseEstimator):
    """

    Fast iterative shrinkage/thresholding Algorithm

    Parameters
    ----------

    lambda_ : float, optional
        regularisation parameter
        default is 0.5

    loss : {'squared-hinge', 'least-square'}, optional
        the loss function to use
        default is 'squared-hinge'

    penalty : {'l11', 'l22', 'l12', 'l21'}, optional
        norm to use as penalty
        default is l11

    n_iter : int, optional
        maximum number of iterations
        default is 1000

    recompute_Lipschitz_constant : bool, optional
        if True, the Lipschitz constant is recomputed at every fit
        if False, it is loaded from (or stored in) a file named after
        the sha1 hash of the kernel matrix
        default is False

    """

    def __init__(self, lambda_=0.5, loss='squared-hinge', penalty='l11', n_iter=1000, recompute_Lipschitz_constant=False):
        self.n_iter = n_iter
        self.lambda_ = lambda_
        self.loss = loss
        self.penalty = penalty
        # Orders of the mixed norm, read from the penalty name ('l21' -> 2, 1).
        self.p = int(penalty[1])
        self.q = int(penalty[2])
        self.recompute_Lipschitz_constant = recompute_Lipschitz_constant

    def fit(self, K, y, Lipschitz_constant=None,  verbose=0, **params):
        """ Fits the estimator

        We want to solve a problem of the form y = KB + b
            where K is a (n_samples, n_kernels*n_samples) matrix.

        Parameters
        ---------
        K : ndarray
            numpy array of shape (n, p)
            K is the concatenation of the p/n kernels
                where each kernel is of size (n, n)

        y : ndarray
            the labels to predict, one per row of K
            y is of size n
                where K.shape : (n, p)

        Lipschitz_constant : float, optional
             allows the user to pre-compute the Lipschitz constant
             (its computation can be very slow, so that parameter is very
             useful if you were to use several times the algorithm on the same data)

        verbose : int, optional
            verbosity of the method : any non zero value will display informations
            and record the dual gap of every iteration in iteration_dual_gap,
            while 0 will display nothing
            default = 0

        **params : ignored

        Returns
        -------
        self

        Raises
        ------
        ValueError : if self.penalty is not one of the supported penalties
        """
        if self.penalty not in ('l11', 'l22', 'l21', 'l12'):
            raise ValueError("penalty must be one of 'l11', 'l22', 'l21', "
                             "'l12', got %r" % (self.penalty,))
        # Re-read the orders so that a penalty changed after construction
        # (e.g. through set_params) is taken into account.
        self.p = int(self.penalty[1])
        self.q = int(self.penalty[2])

        # Any loss other than the two documented ones falls back to the
        # hinge step without scaling K by the labels.
        next_step = hinge_step
        if self.loss=='squared-hinge':
            K = y[:, np.newaxis] * K
            # Equivalent to K = np.dot(np.diag(y), X) but faster
        elif self.loss=='least-square':
            next_step = least_square_step

        (n_samples, n_features) = K.shape
        n_kernels = n_features // n_samples # We assume each kernel is a square matrix
        self.n_samples, self.n_kernels = n_samples, n_kernels

        if Lipschitz_constant is None:
            if self.recompute_Lipschitz_constant:
                # Same formula as the cached version, without the cache file.
                Lipschitz_constant = 1/norm(np.dot(K, K.transpose()), 2)
            else:
                Lipschitz_constant = _load_Lipschitz_constant(K)

        tol = 10**(-6)
        coefs_next = np.zeros(n_features, dtype=float) # coefficients to compute
        Z = np.copy(coefs_next) # a linear combination of the coefficients of the 2 last iterations
        tau_1 = 1

        # Threshold handed to the proximity operator of the penalty.
        threshold = self.lambda_*Lipschitz_constant
        if self.penalty in ('l11', 'l22'):
            elementwise_prox = prox_l11 if self.penalty=='l11' else prox_l22

            def prox(u):
                return elementwise_prox(u, threshold)
        else:
            grouped_prox = prox_l21 if self.penalty=='l21' else prox_l12

            def prox(u):
                return grouped_prox(u, threshold, n_samples, n_kernels)

        if verbose:
            self.iteration_dual_gap = list()

        # Defined up front so that n_iter == 0 still yields a fitted object.
        gap = objective_function = dual_objective_function = float('nan')
        dual_penalisation = dual_loss = float('nan')

        for i in range(self.n_iter):
            coefs_current = coefs_next # B_(k-1) = B_(k)
            coefs_next = prox(Z + Lipschitz_constant*next_step(y, K, Z))

            tau_0 = tau_1 #tau_(k+1) = tau_k
            tau_1 = (1 + sqrt(1 + 4*tau_0**2))/2

            Z = coefs_next + (tau_0 - 1)/tau_1*(coefs_next - coefs_current)

            # Dual problem
            # NOTE(review): the residual below is the hinge one (1 - K.coefs)
            # whatever self.loss is; confirm this is the intended stopping
            # criterion for loss='least-square'.
            objective_var = 1 - np.dot(K, coefs_next)
            objective_var = np.maximum(objective_var, 0) # Shrink
            # Primal objective function
            penalisation = self.lambda_/self.q*(mixed_norm(coefs_next,
                    self.p, self.q, n_samples, n_kernels)**self.q)
            loss = 0.5*np.sum(objective_var**2)
            objective_function = penalisation + loss

            # Dual objective function
            dual_var = objective_var
            if self.lambda_ != 0:
                dual_penalisation = dual_mixed_norm(np.dot(K.T,dual_var)/self.lambda_,
                        n_samples, n_kernels, self.penalty)
                if self.q==1:
                    # Fenchel conjugate of a mixed norm
                    if dual_penalisation > 1:
                        dual_var = dual_var / dual_penalisation
                        # If we did not normalise, dual_penalisation
                        # would be +infinity ...
                    dual_penalisation = 0
                else:
                    # Fenchel conjugate of a squared mixed norm
                    dual_penalisation = self.lambda_/2*(dual_penalisation**2)
            else:
                dual_penalisation = 0
            dual_loss = -0.5*np.sum(dual_var**2) + np.sum(dual_var)
            # trace(np.dot(dual_var[:, np.newaxis], y)) instead of sum(dual_var) ?
            dual_objective_function = dual_loss - self.lambda_/self.q*dual_penalisation
            gap = abs(objective_function - dual_objective_function)

            if verbose:
                sys.stderr.write("Iteration : %d\r" % i )
                self.iteration_dual_gap.append(gap)
                if i%1000 == 0:
                    print("primal objective : %f, dual objective : %f, dual_gap : %f" % (objective_function, dual_objective_function, gap))

            # The first iterations are never allowed to stop the loop.
            if gap<=tol and i>10:
                print("convergence at iteration : %d" %i)
                break

        if verbose:
            print("dual gap : %f" % gap)
            print("objective_function : %f" % objective_function)
            print("dual_objective_function : %f" % dual_objective_function)
            print("dual_penalisation : %f" % dual_penalisation)
            print("dual_loss : %f" % dual_loss)
        self.coefs_ = coefs_next
        self.gap = gap
        self.objective_function = objective_function
        self.dual_objective_function = dual_objective_function

        return self

    def predict(self, K):
        """ Returns the prediction associated to the Kernels represented by K

        Parameters
        ----------
        K : ndarray
            ndarray of size (n_samples, n_kernels*n_samples) representing the kernels

        Returns
        -------
        ndarray : the prediction associated to K
            for the 'squared-hinge' loss, labels in {-1, 1} (a zero
            decision value is mapped to 1); otherwise the raw values
            np.dot(K, coefs_)
        """
        if self.loss=='squared-hinge':
            res = np.sign(np.dot(K, self.coefs_))
            res[res==0] = 1
            return res
        else:
            return np.dot(K, self.coefs_)

    def score(self, K, y):
        """ Returns the score prediction for the given data

        Parameters
        ----------
        K : ndarray
            matrix of observations

        y : ndarray
            the labels corresponding to K

        Returns
        -------
        The percentage of good classification for K with the
        'squared-hinge' loss, None (after printing a message) otherwise
        """
        if self.loss=='squared-hinge':
            return np.sum(np.equal(self.predict(K), y))*100./len(y)
        else:
            print("Score not yet implemented for regression\n")
            return None

    def info(self, K, y):
        """ For test purpose

        Parameters
        ----------
        K : 2D-array
            kernels

        y : ndarray
            labels
        Returns
        -------
        A Bunch (dict like) of informations about the fitted model: score,
        per kernel norms, counts of zero coefficients, objective values,
        dual gap, AUC score and lambda_
        """
        result = Bunch()
        n_samples, n_kernels = self.n_samples, self.n_kernels
        nulled_kernels = 0
        nulled_coefs_per_kernel = list()

        for kernel_coefs in self.coefs_.reshape(n_kernels, n_samples):
            # A kernel is "nulled" when none of its coefficients is non zero.
            if np.count_nonzero(kernel_coefs != 0) == 0:
                nulled_kernels = nulled_kernels + 1
            nulled_coefs_per_kernel.append(np.count_nonzero(kernel_coefs == 0))

        result['score'] = self.score(K, y)
        result['norms'] = by_kernel_norm(self.coefs_, self.p, self.q,
                n_samples, n_kernels)
        result['nulled_coefs'] = len(self.coefs_[self.coefs_==0])
        result['nulled_kernels'] = nulled_kernels
        result['nulled_coefs_per_kernel'] = nulled_coefs_per_kernel
        result['objective_function'] = self.objective_function
        result['dual_objective_function'] = self.dual_objective_function
        result['gap'] = self.gap
        result['auc_score'] = roc_auc_score(y, self.predict(K))
        result['lambda_'] = self.lambda_

        return result

In [6]:
# Device selection. The CUDA branch is left commented out, so every tensor
# moved with `.to(device)` in this notebook stays on the CPU.
#if torch.cuda.is_available():
#    device = torch.device('cuda')
#    torch.backends.cudnn.enabled = False
#else:
device = torch.device('cpu')

In [7]:
def word2vec_padding(list_of_embeddings, length, embedding_length):
    """Return the embeddings padded with zero vectors, or cut, to `length` items.

    The input list is never modified: a new list is always returned.

    :param list_of_embeddings: List of embedding vectors.
    :param length: Number of items the returned list must have.
    :param embedding_length: Size of the zero vectors used as padding.
    :return: New list of exactly `length` items (the first `length`
        embeddings when the input is longer).
    """
    missing = length - len(list_of_embeddings)
    if missing <= 0:
        return list_of_embeddings[:length]
    # All the padding slots refer to one shared zero vector.
    zero_vec = np.zeros(embedding_length)
    return list(list_of_embeddings) + [zero_vec] * missing


def word2vec_sum(list_of_embeddings, embedding_length):
    """Add up all the embeddings into a single vector.

    :param list_of_embeddings: Iterable of vectors of size `embedding_length`.
    :param embedding_length: Size of the returned vector.
    :return: New float array holding the element-wise sum (all zeros when
        there is nothing to add).
    """
    total = np.zeros(embedding_length)
    for vector in list_of_embeddings:
        # Accumulate in place, in the order the embeddings are given.
        np.add(total, vector, out=total)
    return total

class OneHotParser(TextParser):
    """Parser turning each token list into a one-hot row over a fixed vocabulary."""

    def __init__(self, vocabulary, type_func):
        """

        :param vocabulary: List of strings representing the vocabulary.
        :param type_func: Function which converts the output to the desired type, e.g. np.array.
        """
        self.vocab = vocabulary
        self.vocab_size = len(vocabulary)
        TextParser.__init__(self, type_func)

    def __call__(self, x):
        """

        :param x: Sequence of token lists; the tokens of each list are joined
            with single spaces and the resulting sentence is looked up in the
            vocabulary.
        :return: convert_type applied to a (len(x), vocab_size) matrix with a
            single 1 per row. A sentence missing from the vocabulary makes
            list.index raise ValueError (no out of vocabulary column exists).
        """
        encoded = np.zeros((len(x), self.vocab_size))
        for row, tokens in enumerate(x):
            column = self.vocab.index(' '.join(tokens))
            encoded[row, column] = 1

        return self.convert_type(encoded)

def load_list_from_file(file_path):
    """Read a text file and return its non-blank lines, stripped.

    :param file_path: Path of the file to read (opened in text mode).
    :return: List of the lines with surrounding whitespace removed, in file
        order; lines left empty after stripping are dropped.
    """
    with open(file_path) as handle:
        stripped_lines = [line.strip() for line in handle]
    return [line for line in stripped_lines if line]

In [8]:
# Name of the walkthrough variant; it selects the data file below and is
# also passed to the agent and used in the `path` string further down.
task = 'full'
# NOTE(review): despite the .txt extension the file is read as a pickle.
# pickle.load can run arbitrary code, so only open files you trust.
with open(os.getcwd() + '/data/zork_walkthrough_' + task + '.txt', 'rb') as f:
    data = pickle.load(f)

# The loaded object is indexed by the 'actions' and 'states' keys.
raw_actions = data['actions']
raw_states = data['states']

In [9]:
# Hand-written word lists. None of them is referenced by the code visible in
# this notebook (create_actions below reads `dictionary`, not these lists).
verbs = ['go', 'take', 'open', 'grab', 'run', 'walk', 'climb', 'kill', 'light', 'get']

# Larger alternative for basic_actions, kept for reference:
#basic_actions = ['open', 'egg', 'east', 'west', 'north', 'south', 'go', 'up', 'down', 'look', 'take']
basic_actions = ['open', 'egg', 'north', 'climb', 'tree', 'take']

extended_actions = ['grab', 'run', 'climb', 'walk', 'go', 'south', 'east', 'west']

basic_objects = ['egg', 'door', 'tree', 'leaves', 'nest']

# First extension sets: extra objects and multi-word commands.
obj_ext1 = ['bag', 'bottle', 'rope', 'sword', 'lantern', 'knife', 'mat', 'mailbox',
            'rug', 'case', 'axe', 'diamond', 'leaflet', 'news', 'brick']
action_ext1 = ['enter', 'open the window', 'turn lamp on', 'move rug', 'open trap door', 'hit troll with sword']

random_words = ['bring', 'wait', 'test', 'heave', 'squat', 'garbage', 'you', 'no', 'year']

def create_actions():
    """Build the action vocabulary from the notebook globals.

    Reads the module-level `dictionary` (list of words) and `word2vec_model`
    (indexable by word), which must both exist when the function is called.

    :return: Tuple (action_vocabulary, embedding_size) where
        action_vocabulary maps every word of `dictionary` to its vector and
        embedding_size is the length of the vector stored for 'open'
        (KeyError if 'open' is not part of `dictionary`).
    """
    action_vocabulary = {word: word2vec_model[word] for word in dictionary}
    embedding_size = len(action_vocabulary['open'])
    return action_vocabulary, embedding_size

In [10]:
# Pre-trained word vectors obtained through gensim's downloader API
# (presumably fetched on first use — confirm against the gensim docs).
word2vec_model = glove_api.load('glove-wiki-gigaword-50')
embedding_size = word2vec_model.vector_size
# State parser: return_func pads or cuts the embeddings to 65 items (the
# same 65 is given to the agent as input_width), and type_func turns the
# result into a FloatTensor on `device` with an extra leading dimension.
word2vec_parser = Word2Vec(type_func=lambda x: torch.FloatTensor(x).to(device).unsqueeze(0),
                           word2vec_model=word2vec_model,
                           return_func=lambda x: word2vec_padding(x, 65, embedding_size))

In [11]:
# FISTA solver instance (least-square loss, l11 penalty, up to 10000
# iterations). It is not used again in the visible part of the notebook.
fista = Fista(lambda_=0.8, loss='least-square', penalty='l11', n_iter=10000)

In [12]:
# Words making up the action vocabulary: create_actions() looks each of them
# up in word2vec_model, and len(dictionary) is the output size of `network`.
dictionary = ['pray', 'yellow', 'trapdoor', 'open', 'bell', 'touch', 'pile', 'trunk', 'sack', 'inflate', 'southeast',
              'of', 'move', 'match', 'figurine', 'railing', 'with', 'map', 'mirror', 'wind', 'examine', 'north', 'out',
              'trident', 'turn', 'skull', 'throw', 'northwest', 'case', 'bag', 'red', 'press', 'jewels', 'east', 'pump',
              'bolt', 'rusty', 'window', 'douse', 'boat', 'bracelet', 'matchbook', 'basket', 'book', 'coffin', 'bar',
              'rug', 'lid', 'drop', 'nasty', 'wrench', 'light', 'sand', 'bauble', 'kill', 'tie', 'painting', 'sword',
              'wave', 'in', 'south', 'northeast', 'ring', 'canary', 'lower', 'egg', 'all', 'to', 'candles', 'page',
              'and', 'echo', 'emerald', 'tree', 'from', 'rope', 'troll', 'screwdriver', 'torch', 'enter', 'coal', 'go',
              'look', 'shovel', 'knife', 'down', 'take', 'switch', 'prayer', 'launch', 'diamond', 'read', 'up', 'get',
              'scarab', 'west', 'land', 'southwest', 'climb', 'thief', 'raise', 'wait', 'odysseus', 'button', 'sceptre',
              'lamp', 'chalice', 'garlic', 'buoy', 'pot', 'label', 'put', 'dig', 'machine', 'close']

In [13]:
# Zero-mean, identity-covariance Gaussian in 50 dimensions; sampled by the
# test functions to perturb the summed word vectors.
# NOTE(review): 50 is hard-coded and presumably has to match embedding_size
# of the 'glove-wiki-gigaword-50' vectors — confirm before changing models.
noise = MultivariateNormal(torch.zeros(50), torch.eye(50))

In [35]:
def test_accuracy(additional_prints, threshold, snr, its):
    """Measure how often the decoded command matches the walkthrough action.

    For every action of the global `raw_actions`, the word vectors of its
    tokens are summed, Gaussian noise is added, the global `network` scores
    every dictionary word and the words scoring above `threshold` are handed
    to `agent._select_eps_greedy_action`. The command is counted as accurate
    when it contains the same set of tokens as the action.

    Relies on the notebook globals raw_actions, tokenizer, word2vec_model,
    noise, network, agent and device.

    :param additional_prints: When true, print the tokens of the decoded
        command and of the expected action for every mismatch.
    :param threshold: Minimum network output for a word index to be kept.
    :param snr: Norm of the added noise, expressed as a fraction of the norm
        of the clean vector (0 disables the noise).
    :param its: Number of passes over raw_actions.
    :return: List of `its` accuracies, each in [0, 1].
    """
    runs = []
    # Inference only: no autograd graph is needed (same as in test_env).
    with torch.no_grad():
        for _ in range(its):
            accurate = 0
            for action in raw_actions:
                vec = 0
                for token in tokenizer(action):
                    vec += word2vec_model[token]

                # Rescale the sample so that its norm is snr * ||vec||.
                sampled_noise = noise.sample().numpy()
                normalized_noise = snr * np.linalg.norm(vec) * sampled_noise / np.linalg.norm(sampled_noise)
                ground_truth = torch.Tensor(vec + normalized_noise).to(device).unsqueeze(0)

                deepcs_output = network(ground_truth, True).squeeze(0)
                list_of_words = [idx for idx, score in enumerate(deepcs_output)
                                 if score > threshold]

                _, text_command = agent._select_eps_greedy_action(0, list_of_words, None)

                if set(tokenizer(action)) == set(tokenizer(text_command)):
                    accurate += 1
                elif additional_prints:
                    print(tokenizer(text_command))
                    print(tokenizer(action))

            runs.append(accurate * 1.0 / len(raw_actions))
    return runs

In [36]:
def test_env(additional_prints, threshold, snr, its, seed=52):
    """Replay the noisy walkthrough in the environment and record the score.

    For every action of the global `raw_actions`, the word vectors of its
    tokens are summed, Gaussian noise is added, the global `network` scores
    every dictionary word and the words scoring above `threshold` are handed
    to `agent._select_eps_greedy_action`. The resulting command is executed
    with `agent.env.step`; an episode stops early when the step reports done.

    Relies on the notebook globals raw_actions, tokenizer, word2vec_model,
    noise, network, agent and device.

    :param additional_prints: When true, render the environment and print
        the command before each step, and the step reward after it.
    :param threshold: Minimum network output for a word index to be kept.
    :param snr: Norm of the added noise, expressed as a fraction of the norm
        of the clean vector (0 disables the noise).
    :param its: Number of episodes to play.
    :param seed: Value passed to `agent.env.reset` at the start of every episode.
    :return: List of `its` integers, the value of
        `agent.env.env.get_score()` at the end of each episode.
    """
    with torch.no_grad():
        runs = []
        for _ in range(its):
            agent.env.reset(seed)

            for action in raw_actions:
                vec = 0
                for token in tokenizer(action):
                    vec += word2vec_model[token]

                # Rescale the sample so that its norm is snr * ||vec||.
                sampled_noise = noise.sample().numpy()
                normalized_noise = snr * np.linalg.norm(vec) * sampled_noise / np.linalg.norm(sampled_noise)
                ground_truth = torch.Tensor(vec + normalized_noise).to(device).unsqueeze(0)

                deepcs_output = network(ground_truth, True).squeeze(0)
                list_of_words = [word_idx for word_idx, score in enumerate(deepcs_output)
                                 if score > threshold]

                _, text_command = agent._select_eps_greedy_action(0, list_of_words, None)

                if additional_prints:
                    agent.env.render()
                    print(text_command)
                _, rew, done, _ = agent.env.step(text_command)
                if additional_prints:
                    print(rew)
                if done:
                    break

            # The score is read from the environment, not from the rewards.
            runs.append(int(agent.env.env.get_score()))
        return runs

# Default Agent

In [14]:
# Not referenced anywhere else in the visible part of the notebook.
number_of_neighbors=5

# Word -> vector mapping over `dictionary`; this call also rebinds the global
# embedding_size to the length of those vectors.
action_vocabulary, embedding_size = create_actions()

In [15]:
# Agent used by the test functions for its environment (agent.env) and for
# turning word indices into a text command (_select_eps_greedy_action).
# input_width=65 is the padding length used by word2vec_parser.
# NOTE(review): the meaning of the remaining keyword arguments is defined
# in agents.OMP_DDPG, which is not visible here.
agent = OMPDDPG(actions=action_vocabulary,
                state_parser=word2vec_parser,
                embedding_size=embedding_size,
                input_length=embedding_size,
                input_width=65,
                history_size=12,
                model_type='CNN',
                device=device,
                pomdp_mode=False,
                loss_weighting=1.0,
                linear=False,
                improved_omp=False,
                task=task)

In [16]:
# Clears the environment's sparse_reward flag.
# NOTE(review): presumably this makes the environment hand out intermediate
# rewards — confirm in ZorkGym.
agent.env.sparse_reward = False

In [17]:
# Directory string derived from the working directory and the task name;
# nothing in the visible part of the notebook reads it.
path = os.getcwd() + '/deep_cs_' + task + '_cs/'

In [18]:
class MlpBow(nn.Module):
    """Multi-layer perceptron mapping an input vector to ``output_size`` scores.

    The input is flattened to ``(batch, -1)`` and passed through a stack of
    ``nn.Linear`` layers, with a ReLU after every layer except the last one.

    Parameters
    ----------
    embedding_size : int
        Number of input features per sample after flattening.

    output_size : int
        Number of output units.

    hidden_layers : sequence of int
        Width of each hidden layer, in order. May be empty, in which case the
        network is a single linear map from the input to the output.
    """

    def __init__(self, embedding_size, output_size, hidden_layers):
        super().__init__()
        # Consecutive pairs of `sizes` are the (in_features, out_features) of
        # each layer; this also covers the case of no hidden layers at all.
        sizes = [embedding_size] + list(hidden_layers) + [output_size]
        # The layers stay in a ModuleList named `linears`, created in the same
        # order as before, so state_dict keys (`linears.<i>.weight`/`.bias`)
        # remain compatible with previously saved weights.
        self.linears = nn.ModuleList(
            [nn.Linear(n_in, n_out)
             for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )

    def forward(self, x, sigmoid=False):
        """Run the network on a batch.

        Parameters
        ----------
        x : torch.Tensor
            Batch whose first dimension is the batch size; all remaining
            dimensions are flattened into the feature dimension.

        sigmoid : bool, optional
            If True, apply an element-wise sigmoid to the final layer's
            output; otherwise return the raw output. Default is False.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch, output_size)``.
        """
        out = x.view(x.size(0), -1)
        last = len(self.linears) - 1
        for position, layer in enumerate(self.linears):
            out = layer(out)
            # No activation after the final layer: the raw output is returned
            # (or squashed by the optional sigmoid below).
            if position < last:
                out = F.relu(out)
        if sigmoid:
            return torch.sigmoid(out)
        return out

In [19]:
# MLP from an embedding vector to one output per entry of `dictionary`
# (defined outside this cell), with two hidden layers of 100 units each.
network = MlpBow(embedding_size, len(dictionary), [100, 100])

Results, kept here for safekeeping

In [None]:
# results_accuracy = {0.3: {0.0: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.2: [0.9754299754299754, 0.9656019656019657, 0.9508599508599509, 0.9656019656019657, 0.9606879606879607, 0.972972972972973, 0.9680589680589681, 0.9705159705159705, 0.9631449631449631, 0.9803439803439803, 0.9582309582309583, 0.9631449631449631, 0.9631449631449631, 0.9680589680589681, 0.9705159705159705, 0.9656019656019657, 0.9582309582309583, 0.9705159705159705, 0.9656019656019657, 0.972972972972973, 0.9705159705159705, 0.9705159705159705, 0.972972972972973, 0.9606879606879607, 0.9606879606879607, 0.9754299754299754, 0.972972972972973, 0.9656019656019657, 0.9631449631449631, 0.9631449631449631, 0.9631449631449631, 0.972972972972973, 0.9705159705159705, 0.9606879606879607, 0.9778869778869779, 0.9631449631449631, 0.9557739557739557, 0.9631449631449631, 0.9656019656019657, 0.9656019656019657, 0.9533169533169533, 0.9557739557739557, 0.972972972972973, 0.9852579852579852, 0.9680589680589681, 0.9803439803439803, 0.9754299754299754, 0.9680589680589681, 0.9656019656019657, 0.9803439803439803, 0.972972972972973, 0.9705159705159705, 0.9606879606879607, 0.9631449631449631, 0.972972972972973, 0.9828009828009828, 0.9778869778869779, 0.9656019656019657, 0.9705159705159705, 0.9754299754299754, 0.9606879606879607, 0.9631449631449631, 0.9778869778869779, 0.9754299754299754, 0.9705159705159705, 0.9606879606879607, 0.9680589680589681, 0.9582309582309583, 0.9656019656019657, 0.9631449631449631, 0.9582309582309583, 0.9631449631449631, 0.9631449631449631, 0.9680589680589681, 0.9582309582309583, 0.9705159705159705, 0.9606879606879607, 0.9606879606879607, 
0.9754299754299754, 0.9852579852579852], 0.4: [0.8132678132678133, 0.7518427518427518, 0.7395577395577395, 0.7936117936117936, 0.7272727272727273, 0.7321867321867321, 0.7665847665847666, 0.7346437346437347, 0.7518427518427518, 0.7223587223587223, 0.769041769041769, 0.7641277641277642, 0.7714987714987716, 0.7395577395577395, 0.7592137592137592, 0.7936117936117936, 0.7813267813267813, 0.7714987714987716, 0.7788697788697788, 0.7641277641277642, 0.773955773955774, 0.7518427518427518, 0.7469287469287469, 0.7641277641277642, 0.7764127764127764, 0.7911547911547911, 0.773955773955774, 0.7371007371007371, 0.7518427518427518, 0.7567567567567568, 0.7542997542997543, 0.7764127764127764, 0.773955773955774, 0.7444717444717445, 0.7272727272727273, 0.769041769041769, 0.7371007371007371, 0.7911547911547911, 0.7542997542997543, 0.7641277641277642, 0.7985257985257985, 0.8058968058968059, 0.7542997542997543, 0.7542997542997543, 0.8034398034398035, 0.773955773955774, 0.7788697788697788, 0.773955773955774, 0.8108108108108109, 0.7518427518427518, 0.7788697788697788, 0.8329238329238329, 0.7542997542997543, 0.7862407862407862, 0.7641277641277642, 0.7837837837837838, 0.7518427518427518, 0.7960687960687961, 0.7567567567567568, 0.8083538083538083, 0.7714987714987716, 0.7837837837837838, 0.7567567567567568, 0.714987714987715, 0.7641277641277642, 0.7199017199017199, 0.7714987714987716, 0.7395577395577395, 0.7567567567567568, 0.7321867321867321, 0.7518427518427518, 0.7297297297297297, 0.7764127764127764, 0.7542997542997543, 0.7444717444717445, 0.773955773955774, 0.7297297297297297, 0.7321867321867321, 0.7444717444717445, 0.7788697788697788], 0.6: [0.5233415233415234, 0.5036855036855037, 0.515970515970516, 0.542997542997543, 0.5282555282555282, 0.5085995085995086, 0.5135135135135135, 0.5208845208845209, 0.5307125307125307, 0.47665847665847666, 0.5135135135135135, 0.5208845208845209, 0.5405405405405406, 0.5626535626535627, 0.5282555282555282, 0.538083538083538, 0.515970515970516, 
0.5036855036855037, 0.5307125307125307, 0.5233415233415234, 0.49385749385749383, 0.47174447174447176, 0.5331695331695332, 0.5331695331695332, 0.5257985257985258, 0.5110565110565111, 0.515970515970516, 0.5110565110565111, 0.542997542997543, 0.5331695331695332, 0.5282555282555282, 0.547911547911548, 0.457002457002457, 0.515970515970516, 0.48894348894348894, 0.4914004914004914, 0.5282555282555282, 0.4963144963144963, 0.4594594594594595, 0.547911547911548, 0.48894348894348894, 0.5208845208845209, 0.5307125307125307, 0.5700245700245701, 0.515970515970516, 0.547911547911548, 0.542997542997543, 0.5307125307125307, 0.5356265356265356, 0.515970515970516, 0.5503685503685504, 0.5601965601965602, 0.5184275184275184, 0.5307125307125307, 0.4914004914004914, 0.5724815724815725, 0.4742014742014742, 0.5773955773955773, 0.5282555282555282, 0.5331695331695332, 0.47911547911547914, 0.515970515970516, 0.4619164619164619, 0.4692874692874693, 0.5135135135135135, 0.4692874692874693, 0.5135135135135135, 0.542997542997543, 0.48402948402948404, 0.48894348894348894, 0.4668304668304668, 0.44717444717444715, 0.5012285012285013, 0.48157248157248156, 0.538083538083538, 0.5012285012285013, 0.48402948402948404, 0.515970515970516, 0.5085995085995086, 0.4643734643734644], 0.8: [0.35626535626535627, 0.3832923832923833, 0.371007371007371, 0.36855036855036855, 0.39803439803439805, 0.3759213759213759, 0.3955773955773956, 0.3832923832923833, 0.3046683046683047, 0.3366093366093366, 0.3832923832923833, 0.36609336609336607, 0.3046683046683047, 0.35626535626535627, 0.371007371007371, 0.31695331695331697, 0.3538083538083538, 0.36363636363636365, 0.32186732186732187, 0.3366093366093366, 0.3464373464373464, 0.3366093366093366, 0.3415233415233415, 0.35135135135135137, 0.35872235872235875, 0.3955773955773956, 0.3857493857493858, 0.32923832923832924, 0.36609336609336607, 0.3316953316953317, 0.343980343980344, 0.35872235872235875, 0.35135135135135137, 0.32678132678132676, 0.36855036855036855, 0.3464373464373464, 
0.33906633906633904, 0.3316953316953317, 0.37346437346437344, 0.3488943488943489, 0.36363636363636365, 0.35872235872235875, 0.36609336609336607, 0.3538083538083538, 0.36363636363636365, 0.35626535626535627, 0.343980343980344, 0.3415233415233415, 0.35872235872235875, 0.371007371007371, 0.35626535626535627, 0.3366093366093366, 0.3783783783783784, 0.35626535626535627, 0.33415233415233414, 0.343980343980344, 0.3783783783783784, 0.36609336609336607, 0.3538083538083538, 0.35626535626535627, 0.29975429975429974, 0.3488943488943489, 0.33906633906633904, 0.32923832923832924, 0.3464373464373464, 0.36855036855036855, 0.3808353808353808, 0.3316953316953317, 0.3095823095823096, 0.33415233415233414, 0.3415233415233415, 0.3488943488943489, 0.28746928746928746, 0.3194103194103194, 0.3046683046683047, 0.32923832923832924, 0.35135135135135137, 0.28746928746928746, 0.31695331695331697, 0.25552825552825553], 1.0: [0.23587223587223588, 0.2285012285012285, 0.2628992628992629, 0.2678132678132678, 0.2628992628992629, 0.23587223587223588, 0.2727272727272727, 0.25061425061425063, 0.23587223587223588, 0.24324324324324326, 0.22358722358722358, 0.25061425061425063, 0.2678132678132678, 0.24324324324324326, 0.24815724815724816, 0.23587223587223588, 0.2628992628992629, 0.2800982800982801, 0.2457002457002457, 0.2727272727272727, 0.2457002457002457, 0.26535626535626533, 0.24324324324324326, 0.2727272727272727, 0.26044226044226043, 0.23095823095823095, 0.2800982800982801, 0.23095823095823095, 0.23587223587223588, 0.25061425061425063, 0.22604422604422605, 0.24815724815724816, 0.24324324324324326, 0.24324324324324326, 0.24324324324324326, 0.2457002457002457, 0.2334152334152334, 0.22113022113022113, 0.22358722358722358, 0.22358722358722358, 0.28501228501228504, 0.23587223587223588, 0.25552825552825553, 0.257985257985258, 0.2334152334152334, 0.25307125307125306, 0.26044226044226043, 0.24078624078624078, 0.28746928746928746, 0.28255528255528256, 0.2628992628992629, 0.26044226044226043, 
0.2727272727272727, 0.25307125307125306, 0.25307125307125306, 0.25307125307125306, 0.21621621621621623, 0.2727272727272727, 0.26044226044226043, 0.24078624078624078, 0.22358722358722358, 0.22604422604422605, 0.2457002457002457, 0.2457002457002457, 0.23587223587223588, 0.25061425061425063, 0.23587223587223588, 0.25552825552825553, 0.20638820638820637, 0.21375921375921375, 0.21867321867321868, 0.21867321867321868, 0.21375921375921375, 0.22604422604422605, 0.21867321867321868, 0.21867321867321868, 0.22604422604422605, 0.22358722358722358, 0.19656019656019655, 0.20393120393120392]}, 0.5: {0.0: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.2: [0.9631449631449631, 0.9606879606879607, 0.9754299754299754, 0.9606879606879607, 0.9680589680589681, 0.9705159705159705, 0.9877149877149877, 0.9680589680589681, 0.9803439803439803, 0.9680589680589681, 0.9803439803439803, 0.9680589680589681, 0.9754299754299754, 0.9705159705159705, 0.9852579852579852, 0.9852579852579852, 0.9926289926289926, 0.972972972972973, 0.9680589680589681, 0.9606879606879607, 0.9803439803439803, 0.9680589680589681, 0.9656019656019657, 0.972972972972973, 0.972972972972973, 0.9778869778869779, 0.9705159705159705, 0.9656019656019657, 0.9705159705159705, 0.9582309582309583, 0.9705159705159705, 0.9754299754299754, 0.9803439803439803, 0.9557739557739557, 0.972972972972973, 0.972972972972973, 0.9754299754299754, 0.9631449631449631, 0.9754299754299754, 0.9680589680589681, 0.9778869778869779, 0.9631449631449631, 0.9656019656019657, 0.9582309582309583, 0.9656019656019657, 0.9754299754299754, 0.9680589680589681, 0.9778869778869779, 0.972972972972973, 0.9778869778869779, 
0.9778869778869779, 0.9606879606879607, 0.9606879606879607, 0.9754299754299754, 0.9705159705159705, 0.9705159705159705, 0.9803439803439803, 0.9705159705159705, 0.9705159705159705, 0.9778869778869779, 0.972972972972973, 0.9852579852579852, 0.972972972972973, 0.9582309582309583, 0.9778869778869779, 0.9582309582309583, 0.9606879606879607, 0.9705159705159705, 0.9680589680589681, 0.9754299754299754, 0.972972972972973, 0.9656019656019657, 0.9631449631449631, 0.9754299754299754, 0.9582309582309583, 0.9606879606879607, 0.9631449631449631, 0.9582309582309583, 0.9656019656019657, 0.9606879606879607], 0.4: [0.773955773955774, 0.7395577395577395, 0.7714987714987716, 0.7641277641277642, 0.7641277641277642, 0.7469287469287469, 0.7764127764127764, 0.769041769041769, 0.7518427518427518, 0.7764127764127764, 0.7641277641277642, 0.8034398034398035, 0.7272727272727273, 0.742014742014742, 0.7788697788697788, 0.742014742014742, 0.742014742014742, 0.7567567567567568, 0.7469287469287469, 0.7641277641277642, 0.769041769041769, 0.7493857493857494, 0.7542997542997543, 0.769041769041769, 0.7714987714987716, 0.7567567567567568, 0.7469287469287469, 0.7837837837837838, 0.7444717444717445, 0.7813267813267813, 0.7862407862407862, 0.7592137592137592, 0.7592137592137592, 0.7911547911547911, 0.7665847665847666, 0.7567567567567568, 0.7764127764127764, 0.7862407862407862, 0.7886977886977887, 0.7371007371007371, 0.773955773955774, 0.773955773955774, 0.7813267813267813, 0.7493857493857494, 0.769041769041769, 0.773955773955774, 0.7788697788697788, 0.7469287469287469, 0.7714987714987716, 0.8108108108108109, 0.7837837837837838, 0.7493857493857494, 0.7616707616707616, 0.7714987714987716, 0.7371007371007371, 0.7837837837837838, 0.7493857493857494, 0.8108108108108109, 0.769041769041769, 0.7837837837837838, 0.7493857493857494, 0.7444717444717445, 0.7321867321867321, 0.7592137592137592, 0.742014742014742, 0.7567567567567568, 0.769041769041769, 0.7616707616707616, 0.7518427518427518, 0.7297297297297297, 
0.7174447174447175, 0.7542997542997543, 0.7346437346437347, 0.7469287469287469, 0.7100737100737101, 0.7518427518427518, 0.7469287469287469, 0.7100737100737101, 0.7248157248157249, 0.7223587223587223], 0.6: [0.5454545454545454, 0.5282555282555282, 0.5085995085995086, 0.5626535626535627, 0.5331695331695332, 0.5528255528255528, 0.5061425061425061, 0.5036855036855037, 0.538083538083538, 0.5307125307125307, 0.5184275184275184, 0.538083538083538, 0.5061425061425061, 0.5012285012285013, 0.5307125307125307, 0.5085995085995086, 0.4963144963144963, 0.515970515970516, 0.5552825552825553, 0.48402948402948404, 0.5135135135135135, 0.5528255528255528, 0.5233415233415234, 0.538083538083538, 0.5012285012285013, 0.5257985257985258, 0.5503685503685504, 0.5085995085995086, 0.5454545454545454, 0.5012285012285013, 0.547911547911548, 0.5233415233415234, 0.5528255528255528, 0.4987714987714988, 0.5405405405405406, 0.5528255528255528, 0.47665847665847666, 0.5823095823095823, 0.48894348894348894, 0.5184275184275184, 0.5184275184275184, 0.5405405405405406, 0.5577395577395577, 0.5626535626535627, 0.47665847665847666, 0.5773955773955773, 0.5454545454545454, 0.5184275184275184, 0.5405405405405406, 0.515970515970516, 0.5503685503685504, 0.5061425061425061, 0.5454545454545454, 0.47911547911547914, 0.5282555282555282, 0.5282555282555282, 0.5331695331695332, 0.5184275184275184, 0.5085995085995086, 0.5012285012285013, 0.5085995085995086, 0.5208845208845209, 0.47174447174447176, 0.4987714987714988, 0.4643734643734644, 0.4963144963144963, 0.5085995085995086, 0.4914004914004914, 0.4643734643734644, 0.4864864864864865, 0.5208845208845209, 0.538083538083538, 0.5233415233415234, 0.48157248157248156, 0.515970515970516, 0.5110565110565111, 0.49385749385749383, 0.5110565110565111, 0.5061425061425061, 0.515970515970516], 0.8: [0.343980343980344, 0.3808353808353808, 0.32923832923832924, 0.32678132678132676, 0.3366093366093366, 0.3783783783783784, 0.35135135135135137, 0.3857493857493858, 0.3464373464373464, 
0.36117936117936117, 0.3759213759213759, 0.35626535626535627, 0.3464373464373464, 0.37346437346437344, 0.3538083538083538, 0.3488943488943489, 0.3194103194103194, 0.35626535626535627, 0.36855036855036855, 0.41277641277641275, 0.3488943488943489, 0.3488943488943489, 0.3464373464373464, 0.3366093366093366, 0.36609336609336607, 0.35872235872235875, 0.36363636363636365, 0.35135135135135137, 0.4004914004914005, 0.3316953316953317, 0.4004914004914005, 0.3857493857493858, 0.35135135135135137, 0.35626535626535627, 0.36117936117936117, 0.3194103194103194, 0.32186732186732187, 0.371007371007371, 0.3095823095823096, 0.4103194103194103, 0.3415233415233415, 0.3488943488943489, 0.3464373464373464, 0.36117936117936117, 0.3464373464373464, 0.36609336609336607, 0.3808353808353808, 0.31695331695331697, 0.37346437346437344, 0.36855036855036855, 0.343980343980344, 0.371007371007371, 0.3783783783783784, 0.3488943488943489, 0.32923832923832924, 0.3464373464373464, 0.3832923832923833, 0.3415233415233415, 0.3464373464373464, 0.36609336609336607, 0.3071253071253071, 0.3366093366093366, 0.3316953316953317, 0.33906633906633904, 0.343980343980344, 0.32923832923832924, 0.36117936117936117, 0.27764127764127766, 0.3316953316953317, 0.32678132678132676, 0.32432432432432434, 0.3415233415233415, 0.32432432432432434, 0.3071253071253071, 0.33415233415233414, 0.3144963144963145, 0.371007371007371, 0.3194103194103194, 0.343980343980344, 0.3488943488943489], 1.0: [0.2113022113022113, 0.23587223587223588, 0.2727272727272727, 0.257985257985258, 0.21375921375921375, 0.28255528255528256, 0.2334152334152334, 0.257985257985258, 0.257985257985258, 0.24078624078624078, 0.23587223587223588, 0.25061425061425063, 0.26535626535626533, 0.2334152334152334, 0.25552825552825553, 0.2678132678132678, 0.28255528255528256, 0.26044226044226043, 0.22113022113022113, 0.28746928746928746, 0.2678132678132678, 0.24815724815724816, 0.2628992628992629, 0.2628992628992629, 0.2702702702702703, 0.23095823095823095, 
0.2334152334152334, 0.257985257985258, 0.26044226044226043, 0.20884520884520885, 0.22113022113022113, 0.2334152334152334, 0.24324324324324326, 0.25061425061425063, 0.23587223587223588, 0.25307125307125306, 0.21621621621621623, 0.2702702702702703, 0.23587223587223588, 0.23587223587223588, 0.23832923832923833, 0.25307125307125306, 0.25552825552825553, 0.26044226044226043, 0.24815724815724816, 0.24815724815724816, 0.20147420147420148, 0.21867321867321868, 0.25061425061425063, 0.257985257985258, 0.25307125307125306, 0.2678132678132678, 0.257985257985258, 0.2800982800982801, 0.22113022113022113, 0.25061425061425063, 0.2727272727272727, 0.257985257985258, 0.25061425061425063, 0.24324324324324326, 0.22358722358722358, 0.24078624078624078, 0.21867321867321868, 0.24324324324324326, 0.26535626535626533, 0.25552825552825553, 0.20884520884520885, 0.25061425061425063, 0.23832923832923833, 0.22358722358722358, 0.26044226044226043, 0.2678132678132678, 0.24815724815724816, 0.23832923832923833, 0.19656019656019655, 0.2457002457002457, 0.24324324324324326, 0.25552825552825553, 0.22358722358722358, 0.2457002457002457]}, 0.7: {0.0: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.2: [0.9680589680589681, 0.9557739557739557, 0.9631449631449631, 0.9705159705159705, 0.9631449631449631, 0.9606879606879607, 0.9680589680589681, 0.9778869778869779, 0.9680589680589681, 0.9533169533169533, 0.9705159705159705, 0.9606879606879607, 0.9631449631449631, 0.9631449631449631, 0.9631449631449631, 0.9852579852579852, 0.972972972972973, 0.9680589680589681, 0.972972972972973, 0.9778869778869779, 0.9705159705159705, 0.9680589680589681, 0.9656019656019657, 
0.9582309582309583, 0.9557739557739557, 0.972972972972973, 0.9656019656019657, 0.9606879606879607, 0.972972972972973, 0.9754299754299754, 0.9533169533169533, 0.9582309582309583, 0.972972972972973, 0.972972972972973, 0.9631449631449631, 0.9680589680589681, 0.9803439803439803, 0.9582309582309583, 0.9778869778869779, 0.9705159705159705, 0.9680589680589681, 0.9582309582309583, 0.9582309582309583, 0.9778869778869779, 0.9705159705159705, 0.9680589680589681, 0.9582309582309583, 0.9705159705159705, 0.972972972972973, 0.972972972972973, 0.9778869778869779, 0.9803439803439803, 0.9705159705159705, 0.972972972972973, 0.9828009828009828, 0.9705159705159705, 0.9705159705159705, 0.9557739557739557, 0.972972972972973, 0.9582309582309583, 0.9803439803439803, 0.9631449631449631, 0.9631449631449631, 0.9533169533169533, 0.9606879606879607, 0.9582309582309583, 0.9557739557739557, 0.9508599508599509, 0.9606879606879607, 0.9631449631449631, 0.9557739557739557, 0.9656019656019657, 0.9434889434889435, 0.9631449631449631, 0.9606879606879607, 0.9656019656019657, 0.9778869778869779, 0.972972972972973, 0.972972972972973, 0.9606879606879607], 0.4: [0.7518427518427518, 0.7592137592137592, 0.7469287469287469, 0.769041769041769, 0.7764127764127764, 0.7592137592137592, 0.7248157248157249, 0.7641277641277642, 0.7936117936117936, 0.7518427518427518, 0.7764127764127764, 0.7346437346437347, 0.7567567567567568, 0.773955773955774, 0.7542997542997543, 0.7395577395577395, 0.769041769041769, 0.773955773955774, 0.7788697788697788, 0.7960687960687961, 0.7567567567567568, 0.773955773955774, 0.7444717444717445, 0.7567567567567568, 0.7469287469287469, 0.769041769041769, 0.7714987714987716, 0.7469287469287469, 0.7542997542997543, 0.7837837837837838, 0.7321867321867321, 0.7788697788697788, 0.7714987714987716, 0.7862407862407862, 0.7592137592137592, 0.7641277641277642, 0.7542997542997543, 0.7665847665847666, 0.7444717444717445, 0.7764127764127764, 0.7567567567567568, 0.7444717444717445, 0.7665847665847666, 
0.7764127764127764, 0.7665847665847666, 0.8034398034398035, 0.7518427518427518, 0.8034398034398035, 0.7764127764127764, 0.769041769041769, 0.7886977886977887, 0.7764127764127764, 0.7936117936117936, 0.7936117936117936, 0.7714987714987716, 0.7542997542997543, 0.7862407862407862, 0.7641277641277642, 0.769041769041769, 0.7788697788697788, 0.7223587223587223, 0.7174447174447175, 0.7518427518427518, 0.7518427518427518, 0.7321867321867321, 0.7592137592137592, 0.7321867321867321, 0.7444717444717445, 0.7346437346437347, 0.7297297297297297, 0.7641277641277642, 0.7272727272727273, 0.7100737100737101, 0.7616707616707616, 0.7395577395577395, 0.7321867321867321, 0.7469287469287469, 0.7297297297297297, 0.7444717444717445, 0.7444717444717445], 0.6: [0.5110565110565111, 0.5307125307125307, 0.5528255528255528, 0.5626535626535627, 0.515970515970516, 0.5503685503685504, 0.5577395577395577, 0.5233415233415234, 0.5257985257985258, 0.5233415233415234, 0.515970515970516, 0.4864864864864865, 0.5110565110565111, 0.5749385749385749, 0.538083538083538, 0.5233415233415234, 0.5135135135135135, 0.5282555282555282, 0.4963144963144963, 0.547911547911548, 0.5577395577395577, 0.5208845208845209, 0.5405405405405406, 0.5208845208845209, 0.538083538083538, 0.542997542997543, 0.5331695331695332, 0.5233415233415234, 0.5503685503685504, 0.5405405405405406, 0.5233415233415234, 0.515970515970516, 0.5282555282555282, 0.5233415233415234, 0.5282555282555282, 0.6044226044226044, 0.5233415233415234, 0.5233415233415234, 0.5135135135135135, 0.5307125307125307, 0.5773955773955773, 0.5405405405405406, 0.5749385749385749, 0.5061425061425061, 0.5823095823095823, 0.49385749385749383, 0.542997542997543, 0.5331695331695332, 0.5552825552825553, 0.5405405405405406, 0.5307125307125307, 0.5307125307125307, 0.5503685503685504, 0.5724815724815725, 0.5528255528255528, 0.5307125307125307, 0.5626535626535627, 0.5552825552825553, 0.5307125307125307, 0.5331695331695332, 0.4987714987714988, 0.4643734643734644, 0.5135135135135135, 
0.5036855036855037, 0.5208845208845209, 0.4987714987714988, 0.5012285012285013, 0.49385749385749383, 0.5282555282555282, 0.4668304668304668, 0.4742014742014742, 0.5135135135135135, 0.4668304668304668, 0.5208845208845209, 0.5061425061425061, 0.4864864864864865, 0.47665847665847666, 0.5233415233415234, 0.47911547911547914, 0.47911547911547914], 0.8: [0.3783783783783784, 0.36609336609336607, 0.35626535626535627, 0.3906633906633907, 0.3955773955773956, 0.3857493857493858, 0.3464373464373464, 0.33415233415233414, 0.371007371007371, 0.3488943488943489, 0.3488943488943489, 0.35135135135135137, 0.35135135135135137, 0.31695331695331697, 0.3906633906633907, 0.36855036855036855, 0.35872235872235875, 0.35872235872235875, 0.4004914004914005, 0.36609336609336607, 0.40786240786240785, 0.35872235872235875, 0.35135135135135137, 0.3906633906633907, 0.33906633906633904, 0.32923832923832924, 0.36363636363636365, 0.36363636363636365, 0.36363636363636365, 0.3783783783783784, 0.35135135135135137, 0.3906633906633907, 0.3857493857493858, 0.31695331695331697, 0.36609336609336607, 0.3808353808353808, 0.35135135135135137, 0.3488943488943489, 0.3783783783783784, 0.32186732186732187, 0.35872235872235875, 0.39803439803439805, 0.3538083538083538, 0.3759213759213759, 0.3783783783783784, 0.3783783783783784, 0.3906633906633907, 0.3783783783783784, 0.3955773955773956, 0.3955773955773956, 0.36117936117936117, 0.36363636363636365, 0.32923832923832924, 0.36609336609336607, 0.32923832923832924, 0.33415233415233414, 0.39803439803439805, 0.37346437346437344, 0.33415233415233414, 0.3464373464373464, 0.3316953316953317, 0.35872235872235875, 0.32186732186732187, 0.3832923832923833, 0.3538083538083538, 0.33906633906633904, 0.3488943488943489, 0.3759213759213759, 0.31203931203931207, 0.28992628992628994, 0.3194103194103194, 0.33415233415233414, 0.32678132678132676, 0.31695331695331697, 0.3095823095823096, 0.3366093366093366, 0.36117936117936117, 0.37346437346437344, 0.3415233415233415, 0.32923832923832924], 
1.0: [0.2628992628992629, 0.2628992628992629, 0.2628992628992629, 0.2702702702702703, 0.26044226044226043, 0.2457002457002457, 0.257985257985258, 0.28992628992628994, 0.24078624078624078, 0.2334152334152334, 0.24324324324324326, 0.24815724815724816, 0.22604422604422605, 0.28992628992628994, 0.25307125307125306, 0.257985257985258, 0.257985257985258, 0.3071253071253071, 0.2702702702702703, 0.2285012285012285, 0.24815724815724816, 0.26044226044226043, 0.2727272727272727, 0.20393120393120392, 0.2628992628992629, 0.28746928746928746, 0.26535626535626533, 0.2727272727272727, 0.28501228501228504, 0.25061425061425063, 0.26535626535626533, 0.257985257985258, 0.2702702702702703, 0.257985257985258, 0.27764127764127766, 0.31695331695331697, 0.25307125307125306, 0.2334152334152334, 0.23832923832923833, 0.22604422604422605, 0.29975429975429974, 0.2727272727272727, 0.25307125307125306, 0.28746928746928746, 0.26535626535626533, 0.24324324324324326, 0.27764127764127766, 0.2751842751842752, 0.26044226044226043, 0.2751842751842752, 0.28992628992628994, 0.2800982800982801, 0.25307125307125306, 0.2727272727272727, 0.23095823095823095, 0.24078624078624078, 0.25552825552825553, 0.2800982800982801, 0.2727272727272727, 0.2800982800982801, 0.22604422604422605, 0.21867321867321868, 0.26044226044226043, 0.22604422604422605, 0.25552825552825553, 0.21867321867321868, 0.23095823095823095, 0.23095823095823095, 0.2113022113022113, 0.24078624078624078, 0.2334152334152334, 0.25061425061425063, 0.2457002457002457, 0.2702702702702703, 0.23832923832923833, 0.26044226044226043, 0.257985257985258, 0.23095823095823095, 0.24078624078624078, 0.2727272727272727]}, 0.9: {0.0: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.2: [0.941031941031941, 0.9557739557739557, 0.9336609336609336, 0.9533169533169533, 0.9606879606879607, 0.9582309582309583, 0.9533169533169533, 0.9459459459459459, 0.9606879606879607, 0.9582309582309583, 0.9680589680589681, 0.9533169533169533, 0.9705159705159705, 0.9508599508599509, 0.9484029484029484, 0.9656019656019657, 0.9533169533169533, 0.9533169533169533, 0.9484029484029484, 0.9705159705159705, 0.9582309582309583, 0.9533169533169533, 0.9582309582309583, 0.9582309582309583, 0.9459459459459459, 0.9508599508599509, 0.941031941031941, 0.9557739557739557, 0.9631449631449631, 0.9557739557739557, 0.9606879606879607, 0.9778869778869779, 0.9533169533169533, 0.9508599508599509, 0.9557739557739557, 0.941031941031941, 0.9557739557739557, 0.9533169533169533, 0.9434889434889435, 0.9656019656019657, 0.9705159705159705, 0.9434889434889435, 0.9582309582309583, 0.9582309582309583, 0.9459459459459459, 0.9582309582309583, 0.9631449631449631, 0.9754299754299754, 0.9606879606879607, 0.9533169533169533, 0.9606879606879607, 0.9631449631449631, 0.9606879606879607, 0.9557739557739557, 0.9754299754299754, 0.9557739557739557, 0.9606879606879607, 0.9656019656019657, 0.9656019656019657, 0.941031941031941, 0.9459459459459459, 0.9484029484029484, 0.9459459459459459, 0.9312039312039312, 0.9508599508599509, 0.9459459459459459, 0.9606879606879607, 0.9336609336609336, 0.9508599508599509, 0.9459459459459459, 0.9582309582309583, 0.9582309582309583, 0.9459459459459459, 0.9533169533169533, 0.9508599508599509, 0.9754299754299754, 0.9434889434889435, 0.9533169533169533, 0.9557739557739557, 0.9508599508599509], 0.4: [0.7444717444717445, 0.7076167076167076, 0.773955773955774, 0.7321867321867321, 0.7567567567567568, 0.7616707616707616, 0.7592137592137592, 0.7567567567567568, 0.7567567567567568, 0.7616707616707616, 0.7223587223587223, 0.7199017199017199, 0.7125307125307125, 0.7297297297297297, 0.7346437346437347, 0.7395577395577395, 
0.7297297297297297, 0.7174447174447175, 0.7321867321867321, 0.7174447174447175, 0.7395577395577395, 0.7051597051597052, 0.7469287469287469, 0.6928746928746928, 0.7248157248157249, 0.7395577395577395, 0.6928746928746928, 0.7051597051597052, 0.7174447174447175, 0.7567567567567568, 0.742014742014742, 0.7371007371007371, 0.7174447174447175, 0.7321867321867321, 0.7100737100737101, 0.742014742014742, 0.742014742014742, 0.7125307125307125, 0.7297297297297297, 0.7248157248157249, 0.7223587223587223, 0.7395577395577395, 0.7395577395577395, 0.7567567567567568, 0.742014742014742, 0.7469287469287469, 0.7321867321867321, 0.7567567567567568, 0.7616707616707616, 0.7346437346437347, 0.7493857493857494, 0.7592137592137592, 0.7542997542997543, 0.7321867321867321, 0.7297297297297297, 0.7493857493857494, 0.7297297297297297, 0.7518427518427518, 0.7346437346437347, 0.7321867321867321, 0.7321867321867321, 0.7100737100737101, 0.7125307125307125, 0.742014742014742, 0.7125307125307125, 0.7125307125307125, 0.7321867321867321, 0.6855036855036855, 0.7469287469287469, 0.7444717444717445, 0.7272727272727273, 0.7125307125307125, 0.7174447174447175, 0.7100737100737101, 0.6928746928746928, 0.7125307125307125, 0.7027027027027027, 0.7223587223587223, 0.7051597051597052, 0.7346437346437347], 0.6: [0.4668304668304668, 0.4987714987714988, 0.48894348894348894, 0.5356265356265356, 0.5356265356265356, 0.4987714987714988, 0.5282555282555282, 0.5233415233415234, 0.5528255528255528, 0.5135135135135135, 0.5184275184275184, 0.5233415233415234, 0.5036855036855037, 0.4864864864864865, 0.5012285012285013, 0.5036855036855037, 0.4963144963144963, 0.5331695331695332, 0.5307125307125307, 0.515970515970516, 0.48894348894348894, 0.5208845208845209, 0.5135135135135135, 0.5233415233415234, 0.5036855036855037, 0.542997542997543, 0.5208845208845209, 0.5208845208845209, 0.542997542997543, 0.4963144963144963, 0.5307125307125307, 0.5307125307125307, 0.5184275184275184, 0.5528255528255528, 0.515970515970516, 0.5282555282555282, 
0.5184275184275184, 0.5036855036855037, 0.5036855036855037, 0.5282555282555282, 0.5282555282555282, 0.5282555282555282, 0.5208845208845209, 0.5036855036855037, 0.5454545454545454, 0.5503685503685504, 0.5528255528255528, 0.5012285012285013, 0.5085995085995086, 0.5331695331695332, 0.542997542997543, 0.4864864864864865, 0.5257985257985258, 0.5307125307125307, 0.5405405405405406, 0.5601965601965602, 0.5257985257985258, 0.5331695331695332, 0.5135135135135135, 0.542997542997543, 0.44471744471744473, 0.4963144963144963, 0.5110565110565111, 0.48402948402948404, 0.4963144963144963, 0.5012285012285013, 0.5110565110565111, 0.5061425061425061, 0.5061425061425061, 0.48157248157248156, 0.4987714987714988, 0.4864864864864865, 0.4864864864864865, 0.5110565110565111, 0.5454545454545454, 0.5036855036855037, 0.47174447174447176, 0.48894348894348894, 0.4619164619164619, 0.48402948402948404], 0.8: [0.343980343980344, 0.32923832923832924, 0.3316953316953317, 0.343980343980344, 0.36855036855036855, 0.37346437346437344, 0.32678132678132676, 0.4103194103194103, 0.35626535626535627, 0.36363636363636365, 0.36855036855036855, 0.3538083538083538, 0.3808353808353808, 0.3906633906633907, 0.32678132678132676, 0.35626535626535627, 0.3906633906633907, 0.37346437346437344, 0.3759213759213759, 0.36363636363636365, 0.3538083538083538, 0.35872235872235875, 0.3366093366093366, 0.343980343980344, 0.3808353808353808, 0.3366093366093366, 0.3955773955773956, 0.36117936117936117, 0.3538083538083538, 0.3488943488943489, 0.37346437346437344, 0.343980343980344, 0.36363636363636365, 0.371007371007371, 0.33415233415233414, 0.32186732186732187, 0.3808353808353808, 0.36363636363636365, 0.32678132678132676, 0.3857493857493858, 0.40786240786240785, 0.36609336609336607, 0.3316953316953317, 0.33906633906633904, 0.3415233415233415, 0.35872235872235875, 0.35872235872235875, 0.40786240786240785, 0.343980343980344, 0.3538083538083538, 0.3783783783783784, 0.3488943488943489, 0.35872235872235875, 0.35872235872235875, 
0.36117936117936117, 0.37346437346437344, 0.371007371007371, 0.3366093366093366, 0.40786240786240785, 0.39803439803439805, 0.343980343980344, 0.36117936117936117, 0.35872235872235875, 0.37346437346437344, 0.32678132678132676, 0.3366093366093366, 0.33906633906633904, 0.3144963144963145, 0.3464373464373464, 0.3194103194103194, 0.3366093366093366, 0.3144963144963145, 0.343980343980344, 0.3488943488943489, 0.3759213759213759, 0.31695331695331697, 0.3464373464373464, 0.3759213759213759, 0.28746928746928746, 0.3488943488943489], 1.0: [0.2727272727272727, 0.25307125307125306, 0.25061425061425063, 0.25307125307125306, 0.26044226044226043, 0.2751842751842752, 0.2751842751842752, 0.26044226044226043, 0.2678132678132678, 0.29975429975429974, 0.2751842751842752, 0.25061425061425063, 0.2457002457002457, 0.2628992628992629, 0.2727272727272727, 0.27764127764127766, 0.2678132678132678, 0.2678132678132678, 0.2678132678132678, 0.31203931203931207, 0.2628992628992629, 0.26044226044226043, 0.2457002457002457, 0.29484029484029484, 0.26535626535626533, 0.2285012285012285, 0.21867321867321868, 0.2702702702702703, 0.2628992628992629, 0.257985257985258, 0.24815724815724816, 0.25552825552825553, 0.23832923832923833, 0.24815724815724816, 0.24815724815724816, 0.2457002457002457, 0.22113022113022113, 0.20147420147420148, 0.257985257985258, 0.2972972972972973, 0.32186732186732187, 0.27764127764127766, 0.2702702702702703, 0.24815724815724816, 0.29975429975429974, 0.2334152334152334, 0.2628992628992629, 0.257985257985258, 0.2702702702702703, 0.26044226044226043, 0.2285012285012285, 0.29238329238329236, 0.2457002457002457, 0.27764127764127766, 0.2972972972972973, 0.2628992628992629, 0.26044226044226043, 0.2727272727272727, 0.2457002457002457, 0.25061425061425063, 0.2702702702702703, 0.2113022113022113, 0.2334152334152334, 0.28255528255528256, 0.25061425061425063, 0.2678132678132678, 0.21621621621621623, 0.25307125307125306, 0.27764127764127766, 0.2727272727272727, 0.2751842751842752, 
0.23587223587223588, 0.25307125307125306, 0.23832923832923833, 0.2285012285012285, 0.24324324324324326, 0.23587223587223588, 0.23832923832923833, 0.21867321867321868, 0.257985257985258]}}

In [None]:
# results_reward = {0.3: {0.0: [350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350], 0.2: [92, 35, 156, 35, 147, 35, 167, 117, 35, 30, 35, 35, 63, 142, 30, 35, 83, 30, 67, 35, 63, 35, 35, 53, 30, 53, 30, 35, 35, 35, 30, 53, 117, 30, 35, 35, 35, 30, 35, 35, 63, 63, 53, 35, 72, 30, 35, 53, 30, 35, 30, 35, 131, 35, 5, 35, 57, 63, 172, 121, 35, 35, 116, 182, 35, 53, 136, 102, 35, 53, 35, 53, 35, 63, 5, 111, 63, 53, 35, 35], 0.4: [10, 25, 15, 0, 25, 0, 30, 0, 10, 25, 5, 25, 15, 25, 30, 10, 0, 5, 5, 5, 0, 10, 5, 10, 25, 0, 0, 0, 30, 0, 15, 0, 15, 0, 0, 10, 5, 10, 0, 35, 0, 10, 10, 30, 30, 10, 10, 10, 25, 15, 5, 0, 25, 0, 10, 0, 25, 25, 10, 25, 5, 0, 0, 5, 5, 0, 15, 0, 10, 15, 30, 10, 25, 15, 5, 10, 0, 5, 0, 15], 0.6: [0, 5, 0, 0, 35, 0, 10, 0, 10, 5, 0, 10, 0, 0, 0, 0, 5, 25, 5, 0, 0, 0, 0, 0, 10, 0, 5, 15, 0, 0, 0, 0, 0, 10, 0, 10, 0, 5, 25, 10, 0, 0, 0, 0, 0, 10, 10, 10, 0, 0, 0, 0, 0, 10, 0, 0, 10, 10, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 0.8: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0], 1.0: [0, 0, 10, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 10, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 10, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0]}, 0.5: {0.0: [350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 
350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350], 0.2: [53, 30, 35, 53, 53, 63, 30, 30, 35, 35, 35, 111, 35, 53, 53, 35, 35, 35, 35, 98, 121, 35, 30, 63, 30, 53, 35, 63, 35, 53, 67, 53, 30, 30, 30, 127, 112, 53, 35, 63, 35, 35, 72, 30, 63, 67, 53, 35, 111, 35, 63, 35, 53, 35, 63, 35, 35, 121, 35, 167, 35, 132, 147, 187, 223, 67, 67, 53, 63, 15, 35, 53, 240, 67, 53, 157, 233, 63, 15, 131], 0.4: [5, 5, 25, 25, 5, 15, 0, 25, 30, 10, 15, 15, 25, 0, 10, 10, 10, 0, 5, 15, 0, 0, 0, 0, 15, 25, 10, 15, 30, 25, 10, 25, 5, 25, 5, 5, 5, 0, 5, 15, 0, 5, 30, 10, 15, 25, 10, 5, 25, 10, 30, 0, 25, 30, 0, 0, 30, 10, 25, 15, 25, 5, 15, 25, 0, 0, 5, 0, 0, 30, 25, 10, 5, 10, 5, 15, 25, 15, 10, 15], 0.6: [10, 0, 0, 10, 0, 0, 0, 10, 0, 10, 0, 0, 10, 0, 25, 5, 0, 5, 0, 0, 5, 0, 0, 0, 5, 0, 0, 0, 10, 0, 5, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 0, 0, 10, 10, 25, 0, 0, 0, 0, 10, 0, 0, 0, 0, 10, 0, 10, 0, 0, 0, 10, 0, 0, 0, 0, 10, 10, 0, 0, 10, 0, 5, 0, 0, 0], 0.8: [0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10], 1.0: [0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 10, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0]}, 0.7: {0.0: [350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 
350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350], 0.2: [63, 35, 30, 167, 202, 35, 35, 35, 35, 131, 35, 63, 35, 35, 35, 77, 127, 30, 35, 63, 92, 30, 35, 35, 35, 116, 25, 76, 30, 35, 30, 82, 63, 53, 35, 35, 97, 131, 178, 5, 30, 35, 35, 30, 53, 35, 25, 53, 53, 25, 5, 53, 30, 35, 25, 63, 25, 30, 35, 35, 53, 35, 35, 35, 96, 35, 35, 35, 53, 30, 30, 35, 35, 72, 35, 25, 151, 35, 96, 63], 0.4: [5, 0, 25, 5, 5, 30, 0, 10, 25, 0, 10, 5, 5, 0, 0, 10, 10, 35, 10, 25, 5, 0, 0, 0, 0, 5, 5, 0, 5, 15, 0, 25, 5, 10, 10, 0, 0, 5, 5, 30, 25, 10, 5, 5, 10, 10, 25, 0, 15, 15, 25, 15, 25, 25, 5, 30, 10, 0, 25, 15, 10, 10, 0, 35, 0, 25, 10, 10, 5, 5, 0, 25, 0, 0, 0, 0, 10, 10, 10, 10], 0.6: [5, 0, 0, 0, 0, 0, 10, 0, 10, 25, 0, 0, 15, 0, 10, 0, 0, 25, 0, 0, 10, 0, 5, 0, 10, 5, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 5, 0, 10, 0, 10, 0, 10, 0, 10, 25, 5, 10, 10, 10, 0, 0, 10, 10, 10, 0, 0, 10, 0, 0, 10], 0.8: [0, 0, 10, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 0, 0, 10, 0, 10, 10, 10, 0, 0, 0, 10, 0, 0, 10, 0, 0, 0, 0, 10, 10, 0, 0, 25, 10, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 5, 0, 0, 0, 0, 10, 0, 0, 10], 1.0: [10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0]}, 0.9: {0.0: [350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350], 0.2: [53, 35, 25, 30, 30, 30, 30, 35, 101, 35, 63, 10, 35, 53, 30, 35, 53, 35, 30, 
30, 35, 116, 111, 30, 35, 35, 35, 35, 101, 35, 30, 63, 53, 35, 35, 25, 35, 35, 35, 30, 67, 35, 152, 35, 53, 53, 5, 30, 5, 35, 35, 30, 35, 35, 97, 35, 25, 30, 30, 30, 30, 35, 30, 61, 67, 30, 63, 30, 15, 15, 53, 5, 30, 15, 59, 30, 30, 84, 15, 30], 0.4: [0, 15, 15, 15, 0, 10, 5, 30, 30, 0, 10, 0, 10, 35, 5, 25, 5, 25, 25, 10, 10, 30, 0, 20, 0, 30, 5, 25, 5, 5, 25, 25, 0, 10, 25, 5, 10, 5, 5, 0, 5, 10, 25, 10, 0, 25, 10, 30, 30, 0, 30, 25, 10, 15, 10, 10, 25, 10, 15, 0, 0, 15, 25, 0, 0, 10, 5, 25, 5, 25, 5, 30, 25, 35, 10, 30, 10, 25, 0, 0], 0.6: [10, 0, 10, 0, 10, 0, 0, 10, 0, 0, 0, 25, 0, 0, 0, 0, 0, 10, 0, 10, 0, 10, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 10, 0, 0, 10, 0, 0, 10, 25, 0, 10, 0, 0, 10, 0, 0, 5, 0, 0, 0, 10, 0, 10, 0, 0, 5, 10, 0, 0, 0, 10, 0, 10, 0, 0, 10, 0, 0, 10, 10, 0, 0, 10, 0], 0.8: [0, 0, 0, 0, 0, 10, 0, 0, 0, 10, 0, 10, 0, 0, 0, 0, 25, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 10, 0, 0, 0, 0, 0, 10, 0, 10, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 10, 5, 0, 10, 0, 10, 10, 0, 0, 0, 0], 1.0: [10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 25, 0, 10, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 10, 10, 10, 10, 10, 0, 0, 0, 10, 0, 0, 0, 0, 10, 0, 0, 0, 5, 0, 0, 10, 0, 0]}}

In [37]:
# Accuracy sweep: for every (threshold, snr) pair, pool the per-run accuracies
# returned by test_accuracy across the four checkpoint sub-directories 0..3.
# Resulting layout: results_accuracy[threshold][snr] -> flat list of runs.
# NOTE(review): `threshold` and `snr` remain bound at module level after this
# cell; test_timing reads names with these spellings, so they are kept as-is.
results_accuracy = {}

for threshold in (0.3, 0.5, 0.7, 0.9):
    results_accuracy[threshold] = {}
    for snr in (0.0, 0.2, 0.4, 0.6, 0.8, 1.0):
        pooled_runs = []
        results_accuracy[threshold][snr] = pooled_runs
        for subdir in range(4):
            # Restore the weights saved under <path><subdir>/20000/network.
            full_path = path + str(subdir) + '/20000'
            saved_state = torch.load(full_path + '/network')
            network.load_state_dict(saved_state)
            # 20 evaluation passes per checkpoint, appended in order.
            pooled_runs.extend(test_accuracy(False, threshold, snr, 20))

In [38]:
# Dump the raw nested dict (threshold -> snr -> list of accuracies) to the cell output.
print(results_accuracy)

{0.3: {0.0: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.2: [0.9754299754299754, 0.9656019656019657, 0.9508599508599509, 0.9656019656019657, 0.9606879606879607, 0.972972972972973, 0.9680589680589681, 0.9705159705159705, 0.9631449631449631, 0.9803439803439803, 0.9582309582309583, 0.9631449631449631, 0.9631449631449631, 0.9680589680589681, 0.9705159705159705, 0.9656019656019657, 0.9582309582309583, 0.9705159705159705, 0.9656019656019657, 0.972972972972973, 0.9705159705159705, 0.9705159705159705, 0.972972972972973, 0.9606879606879607, 0.9606879606879607, 0.9754299754299754, 0.972972972972973, 0.9656019656019657, 0.9631449631449631, 0.96

In [39]:
# Reward sweep: for every (threshold, snr) pair, pool the per-run rewards
# returned by test_env across the four checkpoint sub-directories 0..3.
# Resulting layout: results_reward[threshold][snr] -> flat list of runs.
# NOTE(review): `threshold` and `snr` remain bound at module level after this
# cell; test_timing reads names with these spellings, so they are kept as-is.
results_reward = {}

for threshold in (0.3, 0.5, 0.7, 0.9):
    results_reward[threshold] = {}
    for snr in (0.0, 0.2, 0.4, 0.6, 0.8, 1.0):
        pooled_runs = []
        results_reward[threshold][snr] = pooled_runs
        for subdir in range(4):
            # Restore the weights saved under <path><subdir>/20000/network.
            full_path = path + str(subdir) + '/20000'
            saved_state = torch.load(full_path + '/network')
            network.load_state_dict(saved_state)
            # 20 evaluation passes per checkpoint, appended in order.
            pooled_runs.extend(test_env(False, threshold, snr, 20))

In [40]:
# Dump the raw nested dict (threshold -> snr -> list of rewards) to the cell output.
print(results_reward)

{0.3: {0.0: [350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350, 350], 0.2: [92, 35, 156, 35, 147, 35, 167, 117, 35, 30, 35, 35, 63, 142, 30, 35, 83, 30, 67, 35, 63, 35, 35, 53, 30, 53, 30, 35, 35, 35, 30, 53, 117, 30, 35, 35, 35, 30, 35, 35, 63, 63, 53, 35, 72, 30, 35, 53, 30, 35, 30, 35, 131, 35, 5, 35, 57, 63, 172, 121, 35, 35, 116, 182, 35, 53, 136, 102, 35, 53, 35, 53, 35, 63, 5, 111, 63, 53, 35, 35], 0.4: [10, 25, 15, 0, 25, 0, 30, 0, 10, 25, 5, 25, 15, 25, 30, 10, 0, 5, 5, 5, 0, 10, 5, 10, 25, 0, 0, 0, 30, 0, 15, 0, 15, 0, 0, 10, 5, 10, 0, 35, 0, 10, 10, 30, 30, 10, 10, 10, 25, 15, 5, 0, 25, 0, 10, 0, 25, 25, 10, 25, 5, 0, 0, 5, 5, 0, 15, 0, 

In [42]:
def test_timing(its):
    """Run the noisy-embedding -> network -> action-selection pipeline for timing.

    Every action in ``raw_actions`` is embedded as the sum of its token
    vectors, perturbed with noise, pushed through ``network`` and handed to
    the agent's action selection. Nothing is returned; the function exists
    only so that a caller can time it.

    Note that ``snr`` and ``threshold`` are NOT parameters: they are read from
    module scope, so whatever values they currently hold are used.

    Parameters
    ----------
    its : int
        number of full passes over ``raw_actions``
    """
    for _ in range(its):
        for action in raw_actions:
            # Sum of the word vectors of the command (0 if it has no tokens).
            vec = sum(word2vec_model[token] for token in tokenizer(action))

            # Rescale the sampled noise so its norm equals snr * ||vec||.
            sampled_noise = noise.sample().numpy()
            noise_scale = snr * np.linalg.norm(vec) / np.linalg.norm(sampled_noise)
            normalized_noise = noise_scale * sampled_noise
            noisy_input = torch.Tensor(vec + normalized_noise)
            ground_truth = noisy_input.to(device).unsqueeze(0)

            deepcs_output = network(ground_truth, True).squeeze(0)
            # Keep the indices whose network output exceeds the threshold.
            list_of_words = [
                idx for idx, score in enumerate(deepcs_output)
                if score > threshold
            ]

            # The result is discarded on purpose: only the elapsed time matters.
            agent._select_eps_greedy_action(0, list_of_words, None)

In [44]:
# Time 20 passes of test_timing and print the mean seconds per pass.
# time.clock() was deprecated in Python 3.3 and removed in 3.8, so it raises
# AttributeError on current interpreters; time.perf_counter() is the
# documented replacement for benchmarking. It measures elapsed wall-clock
# time (time.clock() reported CPU time on Unix), so absolute numbers may
# differ slightly from the previously recorded output.
tick = time.perf_counter()
test_timing(20)
tock = time.perf_counter()
print((tock - tick) / 20)

0.3475349499999993


In [None]:
# Plot mean reward (with a +/- one standard deviation band) against the noise
# level, drawing one curve per threshold stored in results_reward.
colors = ['#396ab1', '#da7c30', '#3e9651', '#cc2529', '#94823d', '#535154', '#006400', '#00FF00', '#800000', '#F08080', '#FFFF00', '#000000', '#C0C0C0']
facecolors = ['#7293cb', '#e1974c', '#84ba5b', '#d35e60', '#ccc210', '#808585']

f, axarr = pl.subplots(1, 1, figsize=(6, 3))

# results_reward is laid out as {threshold: {snr: [reward per run, ...]}}.
# Iterating the inner dict yields its float keys, so the statistics have to be
# computed from the per-SNR lists, and the x axis must be the list of SNR
# levels rather than the scalar `snr` left over from the sweep loops.
for idx, (test_name, rewards_by_snr) in enumerate(results_reward.items()):
    snr_levels = sorted(rewards_by_snr)
    avg = np.array([np.mean(rewards_by_snr[level]) for level in snr_levels])
    std = np.array([np.std(rewards_by_snr[level]) for level in snr_levels])
    # Wrap the palette index so extra series cannot raise IndexError.
    pl.plot(snr_levels, avg, label=test_name, color=colors[idx % len(colors)])
    pl.fill_between(snr_levels, avg - std, avg + std, facecolor=facecolors[idx % len(facecolors)], alpha=0.2, interpolate=True)

leg = pl.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), shadow=True, ncol=3, fontsize=10)
# get_lines() returns the legend's Line2D handles and exists in old and new
# matplotlib alike (the `legendHandles` attribute was renamed in 3.7).
for legobj in leg.get_lines():
    legobj.set_linewidth(3.0)

#pl.suptitle('Egg Quest, Minimal Action Set,\n GloVe, Training with K=all', fontsize=20, y=1.1)
pl.xlabel('SnR')
pl.ylabel('Reward')
pl.show()

In [None]:
# Show the contents of `words` and how many entries it holds.
print(words)
print(len(words))

In [2]:
from sklearn.linear_model import OrthogonalMatchingPursuit
# OMP estimator restricted to at most 4 non-zero coefficients per fit; used by
# test_accuracy_omp as a module-level object.
# NOTE(review): 4 presumably matches the longest command length in words - confirm.
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=4)

In [28]:
def test_accuracy_omp(additional_prints, snr, its, coef_threshold=0.5):
    """Measure how often OMP recovers the exact word set of each action.

    Each action in ``raw_actions`` is embedded as the sum of its token
    vectors and perturbed with noise whose norm is ``snr * ||embedding||``.
    The module-level ``omp`` estimator is then fitted against the transposed
    agent word-embedding matrix, the words whose coefficient exceeds
    ``coef_threshold`` are passed to the agent's action selection, and the
    resulting command is compared (as a set of tokens) with the original.

    Parameters
    ----------
    additional_prints : bool
        when true, print the recovered and the expected token lists for
        every action that was not recovered exactly

    snr : float
        noise norm relative to the embedding norm (0 means no noise)

    its : int
        number of full passes over ``raw_actions``

    coef_threshold : float, optional
        minimum OMP coefficient for a word to be kept
        default is 0.5

    Returns
    -------
    list of float
        one entry per pass: fraction of actions recovered exactly
    """
    # Loop-invariant: build the OMP dictionary once instead of copying the
    # embedding matrix to host memory and transposing it for every action.
    dictionary = agent.word_embeddings.cpu().numpy().T

    runs = []
    for _ in range(its):
        accurate = 0
        for action in raw_actions:
            # Tokenise once; the tokens are needed for embedding and comparison.
            action_tokens = tokenizer(action)
            vec = sum(word2vec_model[token] for token in action_tokens)

            # Rescale the sampled noise so its norm equals snr * ||vec||.
            sampled_noise = noise.sample().numpy()
            normalized_noise = snr * np.linalg.norm(vec) * sampled_noise / np.linalg.norm(sampled_noise)
            ground_truth = vec + normalized_noise

            omp.fit(dictionary, ground_truth)
            coef = omp.coef_
            idx_r, = coef.nonzero()

            # Keep only the selected atoms with a sufficiently large weight.
            list_of_words = [idx for idx in idx_r if coef[idx] > coef_threshold]

            _, text_command = agent._select_eps_greedy_action(0, list_of_words, None)

            # Word order is ignored: only the set of tokens has to match.
            recovered_tokens = tokenizer(text_command)
            if set(action_tokens) == set(recovered_tokens):
                accurate += 1
            elif additional_prints:
                print(recovered_tokens)
                print(action_tokens)

        runs.append(accurate * 1.0 / len(raw_actions))
    return runs

In [29]:
# Single noise-free pass (snr=0, its=1) with mismatch printing enabled.
test_accuracy_omp(True, 0, 1)

dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictio

['out', 'and', 'rope', 'knife']
['get', 'rope', 'and', 'knife']
['railing', 'tie', 'rope', 'down']
['tie', 'rope', 'to', 'railing']
['go', 'torch']
['get', 'torch']
['to', 'and', 'torch', 'lamp']
['take', 'torch', 'and', 'lamp']
['case', 'coffin', 'in', 'to']
['put', 'coffin', 'in', 'case']
['of', 'skull', 'case', 'wait']
['put', 'skull', 'in', 'case']
['case', 'bar', 'to']
['put', 'bar', 'in', 'case']
['out', 'bag', 'and', 'knife']
['get', 'knife', 'and', 'bag']


dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictio

['put', 'case', 'bag']
['put', 'bag', 'in', 'case']
['pile', 'case', 'in', 'button']
['put', 'trunk', 'in', 'case']
['trident', 'case', 'in', 'to']
['put', 'trident', 'in', 'case']
['boat', 'tie', 'sword', 'in']
['throw', 'sceptre', 'in', 'boat']
['of', 'out', 'boat', 'wait']
['get', 'out', 'of', 'boat']
['case', 'in', 'to', 'sceptre']
['put', 'sceptre', 'in', 'case']
['case', 'in', 'to', 'pot']
['put', 'pot', 'in', 'case']
['case', 'from', 'scarab']
['put', 'scarab', 'in', 'case']
['go', 'torch']
['get', 'torch']


dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictio

['case', 'from', 'chalice']
['put', 'chalice', 'in', 'case']
['case', 'in', 'egg', 'to']
['put', 'egg', 'in', 'case']
['case', 'bauble', 'in', 'up']
['put', 'bauble', 'in', 'case']
['case', 'jewels', 'in', 'to']
['put', 'jewels', 'in', 'case']
['basket', 'in', 'torch', 'up']
['put', 'torch', 'in', 'basket']
['basket', 'in', 'screwdriver', 'up']
['put', 'screwdriver', 'in', 'basket']
['basket', 'in', 'coal', 'up']
['put', 'coal', 'in', 'basket']
['basket', 'all', 'up']
['get', 'all', 'from', 'basket']
['from', 'coal', 'shovel']
['put', 'coal', 'in', 'machine']
['basket', 'in', 'diamond', 'up']
['put', 'diamond', 'in', 'basket']
['basket', 'in', 'torch', 'up']
['put', 'torch', 'in', 'basket']
['basket', 'in', 'screwdriver', 'up']
['put', 'screwdriver', 'in', 'basket']
['basket', 'all', 'up']
['get', 'all', 'from', 'basket']
['case', 'in', 'to', 'torch']
['put', 'torch', 'in', 'case']
['case', 'bracelet', 'in', 'up']
['put', 'bracelet', 'in', 'case']
['figurine', 'case', 'in', 'up']
['put

dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictionary. The requested precision might not have been met.

  copy_X=copy_X, return_path=return_path)
dependence in the dictio

[0.914004914004914]