In [1]:
import functools
import io
import os
import pickle
import random

import gensim
import gensim.downloader as glove_api
import numpy as np
import torch
from matplotlib import pyplot as pl

from ZorkGym.text_utils.text_parser import BagOfWords, Word2Vec, TextParser, tokenizer
from agents.OMP_DDPG import OMPDDPG
from agents.OMP_DDPG import OMPDDPG

In [2]:
# Task name forwarded to OMPDDPG(task=...) and used to build the checkpoint directory name.
task = "troll"

In [3]:
# def load_vec(emb_path, nmax=50000):
#     vectors = []
#     word2id = {}
#     with io.open(emb_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
#         next(f)
#         for i, line in enumerate(f):
#             word, vect = line.rstrip().split(' ', 1)
#             vect = np.fromstring(vect, sep=' ')
#             assert word not in word2id, 'word found twice'
#             vectors.append(vect)
#             word2id[word] = len(word2id)
#             if len(word2id) == nmax:
#                 break
#     id2word = {v: k for k, v in word2id.items()}
#     embeddings = np.vstack(vectors)
#     return embeddings, id2word, word2id

In [4]:
# src_path = '/home/chen/ZorkDiscreteDDPG/data/wiki.multi.en.vec'
# nmax = 500000  # maximum number of word embeddings to load

# embeddings, id2word, word2id = load_vec(src_path, nmax)

In [5]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # cuDNN is switched off on GPU runs; the reason is not stated here
    # (NOTE(review): presumably for determinism/compatibility — confirm).
    torch.backends.cudnn.enabled = False

In [6]:
def word2vec_padding(list_of_embeddings, length, embedding_length):
    """Truncate or zero-pad a sequence of embeddings to exactly ``length`` entries.

    :param list_of_embeddings: Sequence of per-token embedding vectors. It is NOT modified.
    :param length: Number of vectors wanted in the output.
    :param embedding_length: Size of each zero vector used as padding.
    :return: A new list holding the first ``length`` embeddings, followed by zero vectors
        if the input was shorter.
    """
    padded = list(list_of_embeddings[:length])
    # A fresh zero array per slot, so an in-place edit of one pad cannot leak into the
    # others. The previous version also appended these pads to the caller's own list.
    padded.extend(np.zeros(embedding_length) for _ in range(length - len(padded)))
    return padded


def word2vec_sum(list_of_embeddings, embedding_length):
    """Return the element-wise sum of the given embedding vectors.

    :param list_of_embeddings: Iterable of vectors, each broadcastable to a zero vector
        of size ``embedding_length``.
    :param embedding_length: Size of the float zero vector the sum starts from. This is
        also the result when the iterable is empty.
    :return: numpy float array holding the accumulated sum.
    """
    total = np.zeros(embedding_length)
    for vec in list_of_embeddings:
        np.add(total, vec, out=total)
    return total

class OneHotParser(TextParser):
    """Encode each tokenised sentence as a one-hot row over a fixed sentence vocabulary."""

    def __init__(self, vocabulary, type_func):
        """

        :param vocabulary: List of strings representing the vocabulary.
        :param type_func: Function which converts the output to the desired type, e.g. np.array.
        """
        self.vocab = vocabulary
        self.vocab_size = len(self.vocab)
        TextParser.__init__(self, type_func)

    def __call__(self, x):
        """One-hot encode a batch of token lists.

        :param x: Sequence of token lists. Each one is joined with single spaces and looked
            up in ``self.vocab`` (first matching index wins).
        :return: ``convert_type`` applied to a float array of shape ``(len(x), vocab_size)``
            with a single 1 per row.
        :raises ValueError: If a joined sentence is not in the vocabulary. There is no
            out-of-vocabulary column.
        """
        # One column per vocabulary entry; out-of-vocabulary sentences are rejected below.
        one_hot = np.zeros((len(x), self.vocab_size))
        for row, token_list in enumerate(x):
            sentence = ' '.join(token_list)
            try:
                column = self.vocab.index(sentence)
            except ValueError:
                raise ValueError('Sentence %r is not in the OneHotParser vocabulary' % (sentence,)) from None
            one_hot[row, column] = 1

        return self.convert_type(one_hot)

def load_list_from_file(file_path):
    """Read a text file and return its non-blank lines with surrounding whitespace removed.

    :param file_path: Path of the file to read.
    :return: List of stripped, non-empty lines, in file order.
    """
    with open(file_path) as handle:
        stripped = (line.strip() for line in handle)
        return [line for line in stripped if line]

In [7]:
# Load the full walkthrough, then append the egg/troll walkthrough to it.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted data files.
data_dir = os.path.join(os.getcwd(), 'data')

with open(os.path.join(data_dir, 'zork_walkthrough_full.txt'), 'rb') as f:
    data = pickle.load(f)

raw_actions = data['actions']
raw_states = data['states']

# `data` intentionally stays bound to the egg/troll walkthrough; a later cell prints data['states'].
with open(os.path.join(data_dir, 'zork_walkthrough_egg_troll.txt'), 'rb') as f:
    data = pickle.load(f)

raw_actions.extend(data['actions'])
raw_states.extend(data['states'])

# sorted() makes the token order reproducible; list(set(...)) order depends on the string hash seed.
actions = sorted({token for action in raw_actions for token in tokenizer(action)})

In [10]:
# Hand-picked word lists. No active code visible in this notebook reads them;
# create_actions() builds its vocabulary from the walkthrough instead.
verbs = 'go take open grab run walk climb kill light get'.split()

#basic_actions = ['open', 'egg', 'east', 'west', 'north', 'south', 'go', 'up', 'down', 'look', 'take']
basic_actions = 'open egg north climb tree take'.split()

extended_actions = 'grab run climb walk go south east west'.split()

basic_objects = 'egg door tree leaves nest'.split()

obj_ext1 = ('bag bottle rope sword lantern knife mat mailbox '
            'rug case axe diamond leaflet news brick').split()
# Multi-word commands, so these must stay an explicit list.
action_ext1 = [
    'enter',
    'open the window',
    'turn lamp on',
    'move rug',
    'open trap door',
    'hit troll with sword',
]

random_words = 'bring wait test heave squat garbage you no year'.split()

def create_actions(action_texts=None, embedding_model=None, tokenize=None):
    """Build the action-word vocabulary from the walkthrough commands.

    Every token of every command becomes an action word mapped to its embedding. The
    empty string ``''`` is added and mapped to an all-zero vector.

    :param action_texts: Iterable of raw command strings. Defaults to the notebook-level
        ``raw_actions``.
    :param embedding_model: Word -> vector lookup supporting ``in``, ``[]`` and
        ``vector_size`` (e.g. gensim KeyedVectors). Defaults to the notebook-level
        ``word2vec_model``.
    :param tokenize: Function splitting a command into tokens. Defaults to ``tokenizer``.
    :return: ``(actions, action_vocabulary, embedding_size, words)``, where:
        - ``actions`` is the sorted list of tokens;
        - ``action_vocabulary`` maps each token, plus ``''``, to its vector;
        - ``embedding_size`` is the vector dimension;
        - ``words`` is the set of tokens.
    :raises KeyError: If any token has no embedding in ``embedding_model``.
    """
    # Defaults are resolved at call time so the notebook can keep calling create_actions().
    if action_texts is None:
        action_texts = raw_actions
    if embedding_model is None:
        embedding_model = word2vec_model
    if tokenize is None:
        tokenize = tokenizer

    words = {token for action in action_texts for token in tokenize(action)}

    # Report every out-of-vocabulary token at once instead of failing on the first.
    missing = sorted(word for word in words if word not in embedding_model)
    if missing:
        raise KeyError('No embedding for action tokens: %s' % ', '.join(missing))

    embedding_size = embedding_model.vector_size
    # Sorted so the vocabulary order is reproducible across runs (set order is not).
    actions = sorted(words)
    action_vocabulary = {word: embedding_model[word] for word in actions}
    # '' is the "no word" entry: a numpy zero vector of the embedding size.
    # NOTE(review): float32 assumes gensim-style float32 embeddings; confirm.
    action_vocabulary[''] = np.zeros(embedding_size, dtype=np.float32)

    return actions, action_vocabulary, embedding_size, words

In [11]:
word2vec_model = glove_api.load('glove-wiki-gigaword-50')
embedding_size = word2vec_model.vector_size
# Number of token slots per state; each state is truncated or zero-padded to this many embeddings.
# NOTE(review): must match the input_width passed to OMPDDPG, since test() views states as (2, input_width, input_length).
STATE_MAX_TOKENS = 65
word2vec_parser = Word2Vec(type_func=lambda x: torch.FloatTensor(x).to(device).unsqueeze(0),
                           word2vec_model=word2vec_model,
                           # partial binds the sizes now. The old lambda looked up the global
                           # `embedding_size` at call time, and a later cell reassigns it.
                           return_func=functools.partial(word2vec_padding,
                                                         length=STATE_MAX_TOKENS,
                                                         embedding_length=embedding_size))

In [13]:
def test(additional_prints, test_iterations):
    total_reward = 0
    iteration = 0
    with torch.no_grad():
        while iteration < test_iterations:
            idx = 0
            try:
                obs = agent.env.reset()
                
                done = False

                full_state = torch.zeros((agent.history_size,
                                          2,
                                          agent.input_width,
                                          agent.input_length), dtype=torch.float32).to(agent.device)
                
                episode_reward = 0
                while not done:
                    obs = agent._parse_state(obs).view(2, agent.input_width, agent.input_length)
                    full_state[:agent.history_size - 1] = full_state[1:]
                    full_state[-1] = obs

                    if additional_prints:
                        print('state:')
                        agent.env.render()
                        #print(data['states'][idx])
                        #print(torch.norm(full_state-word2vec_parser(data['states'][idx])))
                    idx += 1
                    
                    agent_output = agent.network[0](full_state)
                    action, text_command = agent._get_action_beam_search(agent_output, full_state, 0, False, number_of_neighbors, False)
                    if additional_prints:
                        print('action:')
                        #print(agent_output)
                        #print(action)
                        print(text_command)

                    obs, reward, done, has_won = agent.env.step(text_command)

                    episode_reward += reward

                if additional_prints:
                    agent.env.render()

                total_reward += episode_reward
                iteration += 1
            except EnvironmentError:
                print('There was some issue with the Zork test env.')

    return total_reward * 1.0 / test_iterations

# Default Agent

In [14]:
# Neighbour count that test() forwards to agent._get_action_beam_search.
number_of_neighbors = 5

# Build the action vocabulary (and embedding size) that the agent cell below consumes.
actions, action_vocabulary, embedding_size, words = create_actions()

In [15]:
# Evaluation agent. Each state is a stack of per-token embeddings, so the input
# length equals the embedding size.
agent = OMPDDPG(
    actions=action_vocabulary,
    state_parser=word2vec_parser,
    embedding_size=embedding_size,
    input_length=embedding_size,
    input_width=65,  # NOTE(review): must equal the padding length used by word2vec_parser
    history_size=1,
    model_type='CNN',
    device=device,
    pomdp_mode=False,
    loss_weighting=1.0,
    linear=False,
    improved_omp=False,
    task=task,
)

In [16]:
path = os.getcwd() + '/imitation_agent_' + task + '_0.0/0/10000/'
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True for plain state dicts).
# map_location lets weights saved on a GPU machine load onto the current `device`.
agent.network[0].load_state_dict(torch.load(os.path.join(path, 'actor'), map_location=device))
agent.network[1].load_state_dict(torch.load(os.path.join(path, 'critic'), map_location=device))

In [17]:
result = test(additional_prints=True, test_iterations=1)  # one verbose evaluation episode

state:


West of House
You are standing in an open field west of a white house, with a boarded front door.
There is a small mailbox here.


[]
action:
go  north
state:
North of House
You are facing the north side of a white house. There is no door here, and all the windows are boarded up. To the north a narrow
path winds through the trees.


[]
action:
go  east
state:
Behind House
You are behind the white house. A path leads into the forest to the east. In one corner of the house there is a small window
which is slightly ajar.


[]
action:
open  window
state:
Behind House
You are behind the white house. A path leads into the forest to the east. In one corner of the house there is a small window
which is open.


[]
action:
go  west
state:
Kitchen
You are in the kitchen of the white house. A table seems to have been used recently for the preparation of food. A passage leads
to the west and a dark staircase can be seen leading upward. A dark chimney leads down and to the east is a small w

action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sw

action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sw

action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sword here.


[]
action:
 to
state:
Forest
This is a forest, with trees in all directions. To the east, there appears to be sunlight.
There is a sw

In [18]:
# Dump each recorded walkthrough state followed by a blank line.
for state in data['states']:
    print(state, end='\n\n')

[['west', 'of', 'house', 'you', 'are', 'standing', 'in', 'an', 'open', 'field', 'west', 'of', 'a', 'white', 'house', 'with', 'a', 'boarded', 'front', 'door', 'there', 'is', 'a', 'small', 'mailbox', 'here'], []]

[['north', 'of', 'house', 'you', 'are', 'facing', 'the', 'north', 'side', 'of', 'a', 'white', 'house', 'there', 'is', 'no', 'door', 'here', 'and', 'all', 'the', 'windows', 'are', 'boarded', 'up', 'to', 'the', 'north', 'a', 'narrow', 'path', 'winds', 'through', 'the', 'trees'], []]

[['forest', 'path', 'this', 'is', 'a', 'path', 'winding', 'through', 'a', 'dimly', 'lit', 'forest', 'the', 'path', 'heads', 'north', 'south', 'here', 'one', 'particularly', 'large', 'tree', 'with', 'some', 'low', 'branches', 'stands', 'at', 'the', 'edge', 'of', 'the', 'path'], []]

[['up', 'a', 'tree', 'you', 'are', 'about', 'feet', 'above', 'the', 'ground', 'nestled', 'among', 'some', 'large', 'branches', 'the', 'nearest', 'branch', 'above', 'you', 'is', 'above', 'your', 'reach', 'beside', 'you', 'o

In [19]:
print(str(result))  # mean episode reward returned by test()

0.0
