In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from wordle_env import WordleEnv
import torch
import torch.nn as nn
import torch.nn.functional as F




In [72]:
def get_allowed_letters(word_matrix, word_mask, position):
    """
        word_matrix: torch.Tensor of size (num_words, word_length)
        word_mask: torch.Tensor of size (batch_size, num_words)
        position: int
        
        returns
        
        letters_mask: (batch_size, num_letters) -- mask of possible letters
    """
    batch_size = word_mask.size(0)
    word_matrix_expanded = word_matrix[:, position].unsqueeze(0).expand(batch_size, -1)
    
    # print(word_matrix[:, position].shape)
    # print(word_matrix_expanded.shape)
    # print(word_matrix[:, position].unsqueeze(1).shape)
    
    word_matrix_masked = (word_matrix_expanded * word_mask).long()    
    letter_mask = torch.full(fill_value=False, size=(batch_size, num_letters))
    letter_mask.scatter_(index=word_matrix_masked, dim=1, value=True)
    letter_mask[:, 0] = 0
    return letter_mask
    return letter_mask

In [144]:
from torch.distributions import Categorical

class Encoder(nn.Module):
    def __init__(self, letter_tokens, guess_tokens, emb_dim, hid_dim, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        
        self.letter_embedding = nn.Embedding(letter_tokens, emb_dim)
        self.guess_state_embedding = nn.Embedding(guess_tokens, emb_dim)

        self.rnn = nn.LSTM(emb_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, letter_seq, state_seq):
        letters_embedded = self.dropout(self.letter_embedding(letter_seq))
        states_embedded = self.dropout(self.guess_state_embedding(state_seq))

        outputs, (hidden, cell) = self.rnn(letters_embedded + states_embedded)
        
        #outputs = [src len, batch size, hid dim * n directions]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #outputs are always from the top hidden layer
        
        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, dropout):
        super().__init__()
        
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, dropout=dropout)        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden, cell):
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
                
        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden, cell



class RNNAgent(nn.Module):
    def __init__(self, letter_tokens, guess_tokens, emb_dim, hid_dim, output_dim, game_voc_matrix, output_len, sos_token, dropout=0.2):
        super().__init__()
        
        self.encoder = Encoder(letter_tokens, guess_tokens, emb_dim, hid_dim, dropout)
        self.decoder = Decoder(output_dim, emb_dim, hid_dim, dropout)

        modules = [nn.Linear(hid_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, 1)]
        self.V_head = nn.Sequential(*modules)
        
        self.letter_tokens = letter_tokens
        self.game_voc_matrix = game_voc_matrix
        self.output_len = output_len
        self.sos_token = sos_token
    
    def forward(self, letter_seq, state_seq):        
        # tensor to store decoder outputs
        batch_size = letter_seq.shape[1]
        logits = torch.zeros(self.output_len + 1, batch_size, self.letter_tokens)

        hidden, cell = self.encoder(letter_seq, state_seq)

        # compute V
        values = self.V_head(hidden.squeeze())

        # first input to the decoder is the <sos> tokens
        input = torch.full(size=(batch_size,), fill_value=self.sos_token)
        
        letter_mask = torch.full(size=(batch_size, self.letter_tokens), fill_value=True)
        word_mask = torch.full(size=(batch_size, self.game_voc_matrix.shape[0]), fill_value=True)

        # logits: (seq_length, batch_size, num_classes)
        
        actions = torch.zeros(size=(batch_size, self.output_len), dtype=torch.long)
        log_probs = torch.zeros(size=(batch_size,))
        for t in range(1, self.output_len + 1):

            # cur_logits: (batch_size, num_classes)
            # actions: (batch_size,)
            cur_logits, hidden, cell = self.decoder(input, hidden, cell)
            logits[t] = cur_logits

            probs = F.softmax(cur_logits, dim=-1)

            allowed_letters = get_allowed_letters(self.game_voc_matrix, word_mask, t-1)            
            probs[~allowed_letters] = 0.0
            actions_t = Categorical(probs=probs).sample()
            
            word_mask = word_mask & (self.game_voc_matrix[:, t - 1].unsqueeze(0) == actions_t.unsqueeze(1))

            # keep which words are acceptable
            cur_log_probs = torch.log(probs[range(batch_size), actions_t].clip(min=1e-12)).squeeze()

            # letters_allowed_count = allowed_letters.sum(axis=-1)
            # log_probs[letters_allowed_count > 1] += cur_log_probs[letters_allowed_count > 1]
            log_probs += cur_log_probs
            
            actions[:, t-1] = actions_t
            input = actions_t

        return {
            "actions": actions.cpu().numpy(),
            # "logits": logits,
            "log_probs": log_probs,
            "values": values,
        }
    
    def act(self, inputs):
        '''
        input:
            inputs - numpy array, (batch_size x channels x width x height)
        output: dict containing keys ['actions', 'logits', 'log_probs', 'values']:
            'actions' - selected actions, numpy, (batch_size)
            'log_probs' - log probs of selected actions, tensor, (batch_size)
            'values' - critic estimations, tensor, (batch_size)
        '''
        inputs = torch.LongTensor(inputs)
        letter_tokens, state_tokens = inputs[0], inputs[1]
        outputs = self(letter_tokens, state_tokens)
        return outputs

In [145]:
# test get_allowed_letters
from wrappers import SequenceWrapper

num_letters = 29

env = WordleEnv()
env = SequenceWrapper(env, sos_token=1)

word_mask = torch.tensor([[True, False, False], [False, True, False]])

present_letters = get_allowed_letters(torch.from_numpy(env.game_voc_matrix), word_mask, 3).to(torch.long)
present_letters

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
         0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0]])

In [1]:
from wrappers import nature_dqn_env
import numpy as np

env_2 = nature_dqn_env(nenvs=2)




[INFO/Process-1] child process calling self.run()
[INFO/Process-1] child process calling self.run()
[INFO/Process-2] child process calling self.run()
[INFO/Process-2] child process calling self.run()
[DEBUG/MainProcess] created semlock with handle 88
[DEBUG/MainProcess] created semlock with handle 89
[DEBUG/MainProcess] created semlock with handle 90


In [15]:
env_2.step(np.array([5, 5, 5, 5, 5]).reshape(1, -1).repeat(2, axis=0))

(array([[[[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],
 
         [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]]],
 
 
        [[[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],
 
         [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]]]], dtype=int32),
 array([0., 0.]),
 array([ True,  True]),
 ({}, {}))

In [128]:
from tokenizer import Tokenizer

tokenizer = Tokenizer()
game_voc_matrix = torch.FloatTensor(env.game_voc_matrix)
agent = RNNAgent(len(tokenizer.index2letter), len(tokenizer.index2guess_state), 64, 32, len(tokenizer.index2letter), output_len=5, sos_token=1, game_voc_matrix=game_voc_matrix)

In [129]:
obs = env.reset()
letter_tokens = torch.LongTensor(obs[0].reshape(-1, 1))
state_tokens = torch.LongTensor(obs[1].reshape(-1, 1))

In [130]:
agent_output = agent(letter_tokens, state_tokens).values()
agent_output

dict_values([array([[ 5, 20,  3, 16,  7]]), tensor([-16.4809], grad_fn=<AddBackward0>), tensor([0.0550], grad_fn=<AddBackward0>)])

In [131]:
from runners import EnvRunner

In [132]:
runner = EnvRunner(env, agent, nsteps=6)

In [133]:
runner.get_next()

IndexError: tuple index out of range

In [100]:
num_episodes = 10

for _ in range(num_episodes):
    obs = env.reset()
    done = False
    while not done:
        letter_tokens, state_tokens = obs
        letter_tokens, state_tokens = torch.from_numpy(letter_tokens), torch.from_numpy(state_tokens)
        letter_tokens, state_tokens = letter_tokens.reshape(-1, 1), state_tokens.reshape(-1, 1)
        action, logit = agent(letter_tokens, state_tokens)
        if np.random.rand() > 0.5:
            break
        obs, reward, done, info = env.step(action.squeeze().numpy())

In [101]:
obs

array([[ 1, 21, 25, 17, 20,  6,  1,  0,  0,  0,  0,  0,  1,  0,  0,  0,
         0,  0,  1,  0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0,  1,  0,
         0,  0,  0,  0],
       [ 1,  1,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  1,  0,  0,  0,
         0,  0,  1,  0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0,  1,  0,
         0,  0,  0,  0]])

In [449]:
env.game_voc_matrix

array([[21, 25, 17, 20,  6],
       [ 5, 20,  3, 16,  7],
       [18, 14,  3, 22,  7]], dtype=int32)