In [1]:
import copy
import glob
import os
import time
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import sys
sys.path.insert(0,'/home/srikar/pytorch-a2c-ppo-acktr')

from a2c_ppo_acktr import algo
from a2c_ppo_acktr.arguments import get_args
from a2c_ppo_acktr.envs import make_vec_envs
from a2c_ppo_acktr.model import Policy
from a2c_ppo_acktr.storage import RolloutStorage
from a2c_ppo_acktr.utils import get_vec_normalize, update_linear_schedule
from a2c_ppo_acktr.visualize import visdom_plot


import unicodedata
import string
import re
import random
import time
import math

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torchnlp
from torchnlp.datasets import iwslt_dataset

import gym
import gym_nmt
import rllab
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
# Load the IWSLT training split through torchnlp. Each example is indexed by
# language code later in this notebook (e.g. pair['en'], pair['de']).
train_data = iwslt_dataset(train=True)
# Only request GPU tensors when a CUDA device is actually present; a
# hard-coded True makes every later `.cuda()` call fail on CPU-only machines.
USE_CUDA = torch.cuda.is_available()

In [3]:
# Reserved vocabulary indices; they match the two entries that Lang
# pre-populates in `index2word` (0 -> "SOS", 1 -> "EOS").
SOS_token = 0
EOS_token = 1

class Lang:
    """Incrementally built vocabulary for a single language.

    Indices 0 and 1 are reserved for the "SOS" and "EOS" markers, so the
    first real word receives index 2. The class also remembers the longest
    sentence (measured in space-separated tokens) it has been fed.
    """

    def __init__(self, name):
        """Create an empty vocabulary labelled ``name``."""
        self.name = name
        # word -> integer id, and word -> number of occurrences seen so far
        self.word2index = {}
        self.word2count = {}
        # id -> word, pre-seeded with the two reserved markers
        self.index2word = {0: "SOS", 1: "EOS"}
        # next free id; starts at 2 because SOS and EOS are already counted
        self.n_words = 2
        # token count and content of the longest sentence indexed so far;
        # `largest` becomes a list of tokens once a sentence has been indexed
        self.len_largest = 0
        self.largest = ""

    def index_words(self, sentence):
        """Split ``sentence`` on single spaces and register every token."""
        tokens = sentence.split(' ')
        n_tokens = len(tokens)
        if n_tokens > self.len_largest:
            self.len_largest = n_tokens
            self.largest = tokens
        for token in tokens:
            self.index_word(token)

    def index_word(self, word):
        """Register one occurrence of ``word``, assigning a new id if unseen."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        new_index = self.n_words
        self.word2index[word] = new_index
        self.word2count[word] = 1
        self.index2word[new_index] = word
        self.n_words = new_index + 1
def unicode_to_ascii(s):
    """Return ``s`` with combining marks removed.

    The string is NFD-decomposed so accented letters split into a base
    character plus combining marks, then every character whose Unicode
    category is 'Mn' (nonspacing mark) is dropped.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

# Lowercase, trim, and strip diacritics. Punctuation and other non-letter
# characters are kept, because the regex clean-up below is disabled.
def normalize_string(s):
    """Return ``s`` lowercased, stripped of surrounding whitespace and accents."""
    lowered = s.lower()
    trimmed = lowered.strip()
    # Disabled clean-up steps, kept for reference:
    #     s = re.sub(r"([.!?])", r" \1", s)
    #     s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return unicode_to_ascii(trimmed)


In [4]:
def read_langs(lang1, lang2, reverse=False):
    """Create two empty vocabularies and return them as (input, output).

    Without ``reverse`` the input vocabulary is named ``lang1`` and the
    output vocabulary ``lang2``; with ``reverse`` the two names are swapped.
    No data is read here despite the function's name.
    """
    input_name, output_name = (lang2, lang1) if reverse else (lang1, lang2)
    return Lang(input_name), Lang(output_name)

def prepare_data(lang1_name, lang2_name,train_data, reverse=False):
    """Build the input and output vocabularies from ``train_data``.

    Every example is indexed by language code: the ``lang1_name`` sentence
    feeds the first returned vocabulary and the ``lang2_name`` sentence feeds
    the second.

    NOTE(review): ``reverse`` only swaps the *names* given to the two Lang
    objects (via read_langs). The sentences indexed into each vocabulary are
    the same either way, so with ``reverse=True`` the vocabulary named
    ``lang2_name`` actually holds ``lang1_name`` words. Confirm whether that
    is intended before relying on ``Lang.name``.

    Returns:
        (input_lang, output_lang) as Lang instances.
    """
    vocab_in, vocab_out = read_langs(lang1_name, lang2_name, reverse)
    print("Read %s sentence pairs" % len(train_data))

    print("Indexing words...")
    for example in train_data:
        vocab_in.index_words(example[lang1_name])
        vocab_out.index_words(example[lang2_name])

    return vocab_in, vocab_out
# Build the two vocabularies over the whole training set. prepare_data always
# indexes pair['en'] into the first result and pair['de'] into the second;
# reverse=True only swaps the names stored on the Lang objects.
input_lang, output_lang = prepare_data('en', 'de',train_data, True)
# Show one random example as a sanity check (unseeded, so it varies per run).
print(random.choice(train_data))
def indexes_from_sentence(lang, sentence):
    """Map each space-separated token of ``sentence`` to its id in ``lang``.

    Raises:
        KeyError: if a token (including the empty string produced by
            consecutive spaces) is missing from ``lang.word2index``.
    """
    ids = []
    for token in sentence.split(' '):
        ids.append(lang.word2index[token])
    return ids

def variable_from_sentence(lang, sentence):
    """Encode ``sentence`` as a column of token ids terminated by EOS.

    Returns a LongTensor wrapped in ``Variable`` with shape (n_tokens + 1, 1),
    moved to the GPU when ``USE_CUDA`` is truthy.
    """
    token_ids = indexes_from_sentence(lang, sentence) + [EOS_token]
    column = torch.LongTensor(token_ids).view(-1, 1)
    var = Variable(column)
    return var.cuda() if USE_CUDA else var

def variables_from_pair(pair,lang1,lang2):
    """Encode one example as an (input, target) pair of id tensors.

    ``pair[lang1]`` is encoded with the module-level ``input_lang``
    vocabulary and ``pair[lang2]`` with the module-level ``output_lang``.
    """
    source = variable_from_sentence(input_lang, pair[lang1])
    target = variable_from_sentence(output_lang, pair[lang2])
    return source, target


Read 196884 sentence pairs
Indexing words...
{'en': 'Joy is shared. Achievement is shared.', 'de': 'Freude wird geteilt. Leistung wird geteilt.'}


In [5]:
# Number of training epochs for the loop further down.
n_epochs = 10
# Curriculum: the training loop reads training_scheme[epoch] as the number of
# missing words for that epoch (1 for the first 10 entries, then 2, then 3).
# The list has 50 entries, so it covers more epochs than n_epochs uses.
training_scheme = [1]*10 + [2]*20 + [3]*20
# 'nmt-v0' is presumably registered by the `gym_nmt` import - confirm.
env = gym.make('nmt-v0')
# Project-specific setup call; it receives the output vocabulary size, the
# longest output sentence length (in tokens) and the id -> word table.
env.space_init(output_lang.n_words,output_lang.len_largest,output_lang.index2word)

In [6]:
# Settings forwarded to the policy's base network: vocabulary sizes for both
# languages and the longest output sentence length (in tokens).
base_kwargs = {
    'recurrent': True,
    'input_nwords': input_lang.n_words,
    'output_nwords': output_lang.n_words,
    'max_length': output_lang.len_largest,
}
# 'Attn' selects the base network; presumably an attention model added in the
# local fork of a2c_ppo_acktr - confirm against a2c_ppo_acktr.model.
actor_critic = Policy(
    env.observation_space.shape,
    env.action_space,
    'Attn',
    base_kwargs,
)


In [7]:
# PPO trainer for the actor-critic.
# The positional arguments are presumably clip_param, ppo_epoch,
# num_mini_batch, value_loss_coef and entropy_coef, as in the upstream
# pytorch-a2c-ppo-acktr PPO constructor - confirm against the local fork.
#
# lr and eps must be passed as floats: the previous int(7e-8) and int(1e-5)
# both truncated to 0, giving a zero learning rate (no parameter updates) and
# a zero epsilon (no protection against division by zero in the optimiser).
agent = algo.PPO(actor_critic, 0.2, 4, 1,
                 0.5, 0.01,
                 lr=7e-8,
                 eps=1e-5,
                 max_grad_norm=0.5)



In [8]:
# Number of sentence pairs (episodes) sampled per epoch in the training loop.
num_steps = 30
# Re-declared here with the same value as in the earlier setup cell.
n_epochs = 10
# The next three values are passed to rollouts.compute_returns();
# presumably GAE on/off, the discount factor and the GAE lambda - confirm
# against a2c_ppo_acktr.storage.
use_gae = False
gamma = 0.99
tau = 0.95

In [9]:
# Curriculum training loop.
#
# Each epoch collects `num_steps` episodes (one random sentence pair each)
# into a single rollout buffer and then performs one PPO update.
#
# Changes from the previous version:
#   * runs exactly n_epochs epochs (was n_epochs + 1);
#   * the buffer is sized for the worst-case episode length, so the write
#     cursor can no longer wrap around and overwrite earlier transitions;
#   * the policy reads the storage at its current write cursor instead of at
#     the episode counter, so it sees the latest observation / hidden state;
#   * the forced EOS action is created with the sampled action's device,
#     dtype and shape;
#   * the per-epoch losses are printed instead of being discarded.
#
# NOTE(review): `rollouts.step` is assumed to be the write cursor that
# RolloutStorage.insert() advances (as in upstream pytorch-a2c-ppo-acktr) -
# confirm against the local fork.
for epoch in range(n_epochs):

    n_missing_words = training_scheme[epoch]
    # An episode takes at most 2*n_missing_words sampled actions plus one
    # forced EOS action.
    max_episode_steps = 2*n_missing_words + 1
    rollouts = RolloutStorage(num_steps*max_episode_steps, 1,
                        env.observation_space.shape, env.action_space,
                        actor_critic.recurrent_hidden_state_size)

    for step in range(num_steps):

        training_pair = variables_from_pair(random.choice(train_data),'en','de')
        input_variable, target_variable = training_pair

        env.my_init(input_variable.cpu().numpy(), target_variable.cpu().numpy(),n_missing_words)
        obs = env.reset()
        # Place the fresh episode's first observation where the policy will
        # read it next, i.e. at the storage's current write cursor.
        rollouts.obs[rollouts.step].copy_(torch.tensor(obs))

        context = None
        for n in range(max_episode_steps):
            cursor = rollouts.step
            # The first call of an episode builds the context; later calls
            # reuse the context returned by the previous call.
            if n == 0:
                act_kwargs = {'first_time': True}
            else:
                act_kwargs = {'first_time': False, 'context': context}

            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states, context = actor_critic.act(
                        rollouts.obs[cursor],
                        rollouts.recurrent_hidden_states[cursor],
                        rollouts.masks[cursor], **act_kwargs)

            if n == max_episode_steps - 1:
                # Step budget exhausted: force the episode to end with EOS.
                # NOTE(review): action_log_prob still belongs to the sampled
                # action, not to the forced EOS - confirm this is acceptable
                # for the PPO ratio.
                action = torch.full_like(action, EOS_token)

            obs, reward, done, infos = env.step(action.cpu().numpy()[0][0])

            masks = torch.FloatTensor([[0.0] if done else [1.0]])

            rollouts.insert(torch.tensor(obs), recurrent_hidden_states, action,
                            action_log_prob, value, torch.tensor(reward), masks)

            if action.item() == EOS_token:
                break

    # No bootstrap value after the last stored transition.
    # NOTE(review): episodes that end early leave unused slots at the tail of
    # the buffer, which compute_returns()/update() still process - confirm.
    next_value = 0

    rollouts.compute_returns(next_value, use_gae, gamma, tau)
    value_loss, action_loss, dist_entropy = agent.update(rollouts)
    print("epoch %d: value_loss=%s action_loss=%s dist_entropy=%s"
          % (epoch, value_loss, action_loss, dist_entropy))

    rollouts.after_update()




action probs are tensor([[4.8230e-06, 4.8228e-06, 4.8247e-06,  ..., 4.8231e-06, 4.8226e-06,
         4.8257e-06]], device='cuda:0')
action probs are tensor([[4.8230e-06, 4.8228e-06, 4.8247e-06,  ..., 4.8231e-06, 4.8226e-06,
         4.8257e-06]], device='cuda:0')
action probs are tensor([[4.8230e-06, 4.8228e-06, 4.8247e-06,  ..., 4.8231e-06, 4.8225e-06,
         4.8257e-06]], device='cuda:0')
action probs are tensor([[4.8230e-06, 4.8228e-06, 4.8247e-06,  ..., 4.8231e-06, 4.8225e-06,
         4.8257e-06]], device='cuda:0')
action probs are tensor([[4.8230e-06, 4.8228e-06, 4.8247e-06,  ..., 4.8231e-06, 4.8226e-06,
         4.8257e-06]], device='cuda:0')
action probs are tensor([[4.8230e-06, 4.8228e-06, 4.8247e-06,  ..., 4.8231e-06, 4.8226e-06,
         4.8257e-06]], device='cuda:0')
action probs are tensor([[4.8230e-06, 4.8228e-06, 4.8247e-06,  ..., 4.8231e-06, 4.8226e-06,
         4.8257e-06]], device='cuda:0')
action probs are tensor([[4.8230e-06, 4.8228e-06, 4.8247e-06,  ..., 4.8231e-

RuntimeError: cuda runtime error (59) : device-side assert triggered at /pytorch/aten/src/THC/generated/../THCReduceAll.cuh:317

In [None]:
# Debug inspection: show `value` as left by the last actor_critic.act() call.
value

In [None]:
# Debug inspection: show the last `action` used in the training loop.
action

In [None]:
# Reciprocal of one printed action probability (about 2.07e5); if the
# distribution is uniform this approximates the number of actions.
1/4.8231e-06

In [None]:
# Minimal CUDA probe; presumably checks whether the GPU context is still
# usable after the device-side assert above - confirm.
torch.tensor(1).cuda()