In [1]:
%load_ext autoreload
%autoreload 2

from wordle_env import WordleEnv
from wrappers import SequenceWrapper, ReshapeWrapper, TensorboardSummaries
from wrappers import nature_dqn_env
import numpy as np

num_letters = 29

env = WordleEnv()
env = SequenceWrapper(env, sos_token=1)
env = ReshapeWrapper(env)
env = TensorboardSummaries(env, prefix='wordle')


[DEBUG/MainProcess] created semlock with handle 79
[DEBUG/MainProcess] created semlock with handle 80
[DEBUG/MainProcess] created semlock with handle 82


In [2]:
from tokenizer import Tokenizer

tokenizer = Tokenizer()
tokenizer.guess_state2index['<RIGHT>']

3

In [3]:
env.reset()

array([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])

In [18]:
data = list()

In [33]:
import torch
from torch.nn.utils.rnn import pad_sequence

In [34]:
data_size = 100000
voc_size = len(env.game_ans_matrix)

for _ in range(data_size):
    idx = np.random.choice(voc_size)
    action = env.game_ans_matrix[idx]
    obs, rew, done, info = env.step(action)    
    data.append(torch.tensor(obs))

tensor_data = torch.cat(data, dim=0)

In [35]:
(tensor_data[:, 1, :] == tokenizer.guess_state2index['<RIGHT>']).sum()

tensor(108240)

In [36]:
true_letters = tensor_data[:, 0, :][tensor_data[:, 1, :] == tokenizer.guess_state2index['<RIGHT>']]

In [37]:
lens = (tensor_data[:, 1, :] == tokenizer.guess_state2index['<RIGHT>']).sum(dim=1)

In [38]:
targets = []
pos = 0

for i in range(len(tensor_data)):
    targets.append(true_letters[pos:pos + lens[i]])
    pos += lens[i]

In [39]:
tensor_data = tensor_data[lens > 0]
targets = pad_sequence(targets, batch_first=True)[lens > 0, :]

In [40]:
tensor_data.shape

torch.Size([57595, 2, 36])

In [41]:
targets.shape

torch.Size([57595, 9])

In [49]:
targets = targets[:, :np.quantile(lens, 0.98).astype(np.int64)]

In [50]:
targets

tensor([[21,  3,  0,  0],
        [21,  3,  0,  0],
        [21,  3,  0,  0],
        ...,
        [22,  0,  0,  0],
        [22,  0,  0,  0],
        [22,  0,  0,  0]])

In [89]:
import torch.nn as nn
import torch
from model import Encoder, Decoder, AttentionLayer

In [214]:
class Model(nn.Module):
    def __init__(self, letter_tokens, guess_tokens, emb_dim, hid_dim, output_dim, game_voc_matrix, output_len, sos_token, dropout=0.2):
        super().__init__()
        
        self.encoder = Encoder(letter_tokens, guess_tokens, emb_dim, hid_dim, dropout)
        self.decoder = Decoder(output_dim, emb_dim, hid_dim, dropout)
        self.attention = AttentionLayer(hid_dim)

        self.logit_head = nn.Linear(hid_dim, output_dim)
        
        self.letter_tokens = letter_tokens
        self.game_voc_matrix = game_voc_matrix
        self.output_len = output_len
        self.sos_token = sos_token
        self.debug_mode = False

    def debug(self, mode=False):
        self.debug_mode = mode
    
    def forward(self, letter_seq, state_seq, targets=None):
        """
            inputs:
                letter_seq: (batch_size x sequence_length)
                state_seq: (batch_size x sequence_length)
                
            outputs:
                
        """
        
        maxlen = letter_seq.shape[1]
        lengths = (letter_seq != 0).sum(axis=-1)
        mask = (torch.arange(maxlen)[None, :] < lengths[:, None]).bool()

        # tensor to store decoder outputs
        batch_size = letter_seq.shape[0]
        
        output_len = self.output_len if targets is None else targets.shape[1]
        logits = torch.zeros(batch_size, output_len, self.letter_tokens)

        encoder_hiddens, encoder_cells = self.encoder(letter_seq, state_seq)        
        hidden, cell = encoder_hiddens[:, -1:, :], encoder_cells[:, -1:, :]
        hidden, cell = hidden.permute(dims=(1, 0, 2)), cell.permute(dims=(1, 0, 2))

        # first input to the decoder is the <sos> tokens
        input = torch.full(size=(batch_size,), fill_value=self.sos_token)
        
        actions = torch.zeros(size=(batch_size, self.output_len), dtype=torch.long)
        log_probs = torch.zeros(size=(batch_size,))
        
        if self.debug_mode:
            fig, ax = plt.subplots(1, self.output_len)
        
        for t in range(1, output_len):

            # cur_logits: (batch_size, num_classes)
            # actions: (batch_size,)
            _, hidden, cell = self.decoder(input, hidden, cell)
            
            decoder_hidden = hidden
            attentive_hidden = self.attention(encoder_hiddens, decoder_hidden.permute(dims=(1, 0, 2)), encoder_hiddens, mask).squeeze(1)
            
            if self.debug_mode:
                map_reshaped = self.attention.attention_map.squeeze().reshape(6, 6).detach().numpy()
                ax[t - 1].imshow(map_reshaped)
            
            cur_logits = self.logit_head(attentive_hidden)
            logits[:, t, :] = cur_logits
            
            if targets is not None:
                input = targets[:, t - 1]
            else:
                input = cur_logits.argmax(1)
                actions[:, t - 1] = input

        if self.debug_mode:
            plt.show()

        return {
            "actions": actions.cpu().numpy(),
            "logits": logits,
            "log_probs": log_probs
        }
    
    def act(self, inputs, targets=None):
        '''
        input:
            inputs - numpy array, (batch_size x sequences x sequence_length)
        output: dict containing keys ['actions', 'logits', 'log_probs', 'values']:
            'actions' - selected actions, numpy, (batch_size, sequence_length)
            'log_probs' - log probs of selected actions, tensor, (batch_size)
            'values' - critic estimations, tensor, (batch_size)
        '''
        inputs = torch.LongTensor(inputs)
        letter_tokens, state_tokens = inputs[:, 0, :], inputs[:, 1, :]
        outputs = self(letter_tokens, state_tokens, targets)
        return outputs

In [215]:
policy = Model(
    len(tokenizer.index2letter), 
    len(tokenizer.index2guess_state), 
    32, 128, 
    len(tokenizer.index2letter), 
    output_len=5, 
    sos_token=1, 
    game_voc_matrix=env.game_voc_matrix
)

In [216]:
policy.act(tensor_data[:10], targets[:10])['logits'].shape

torch.Size([10, 4, 29])

In [217]:
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam

In [218]:
dataset = TensorDataset(tensor_data, targets)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [219]:
optimizer = Adam(policy.parameters(), lr=1e-3)

In [230]:
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.letter2index['<PAD>'])

for epoch in range(5):
    for batch_x, batch_y in loader:
        logits = policy.act(batch_x, batch_y)['logits']
        optimizer.zero_grad()
        loss = criterion(logits.permute(0, 2, 1), batch_y)
        loss.backward()        
        optimizer.step()
        
        print(f"Loss = {loss.item()}", end="\r")
    print()

Loss = 1.6858958005905151

KeyboardInterrupt: 

In [233]:
obs.shape

(2, 36)

tensor([[[ 1,  6,  3,  6,  6, 27,  1, 15, 17, 23, 21,  7,  1, 21, 10,  3, 20,
          13,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0],
         [ 1,  5,  5,  4,  5,  5,  1,  4,  5,  5,  4,  5,  1,  3,  5,  3,  5,
           5,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0]]])

In [262]:
i = np.random.randint(10000)
policy.act(tensor_data[i:i + 1])

{'actions': array([[7, 7, 7, 7, 0]]),
 'logits': tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000],
          [-1.7884, -2.0632, -1.2262, -0.1962, -1.5051, -1.7639, -1.2554,
            6.2180,  0.4405,  0.0379, -1.8652, -1.0941, -0.4712, -0.4245,
           -0.8784, -0.0734,  0.4749,  0.2367, -1.2355, -0.9595, -0.9706,
            0.4247, -1.2451, -0.8168,  0.6191, -0.0391, -0.0844, -0.9520,
           -1.1403],
          [-1.9342, -2.3574, -1.3829,  1.5797, -0.9328, -2.1390, -1.0152,
            2.0759,  0.7864,  1.4674, -0.1069,  0.2686, -0.4730, -1.0109,
           -0.0904,  0.1697,  0.4819,  0.0960, -1.3320, -0.4715, -1.6435,
            0.8346, -2.1798, -0.4123,  0.1721,  0.3797, -0.3087, -0.2949,
           -0.6558],
 

In [263]:
obs = tensor_data[i]

In [264]:
obs[0, obs[1] == tokenizer.guess_state2index['<RIGHT>']]

tensor([7])

In [265]:
(4 ** 1356 - 9 ** 4824) % 31

0

In [266]:
(4 ** 6 - 9 ** 24) % 31

0

In [268]:
(2 ** 2 - 3 ** 18) % 31

3

In [125]:
logits.shape

torch.Size([32, 4, 29])

In [84]:
targets.shape

torch.Size([57595, 4])

In [69]:
targets[:4, 1]

tensor([3, 3, 3, 0])