In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from wordle_env import WordleEnv
import torch
import torch.nn as nn
import torch.nn.functional as F




In [2]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7fda5009e970>

In [3]:
DEVICE = torch.device('cpu')

In [4]:
def get_allowed_letters(word_matrix, word_mask, position):
    """
        word_matrix: torch.Tensor of size (num_words, word_length)
        word_mask: torch.Tensor of size (batch_size, num_words)
        position: int
        
        returns
        
        letters_mask: (batch_size, num_letters) -- mask of possible letters
    """
    batch_size = word_mask.size(0)
    word_matrix_expanded = word_matrix[:, position].unsqueeze(0).expand(batch_size, -1)
    
    # print(word_matrix[:, position].shape)
    # print(word_matrix_expanded.shape)
    # print(word_matrix[:, position].unsqueeze(1).shape)
    
    word_matrix_masked = (word_matrix_expanded * word_mask).long()    
    letter_mask = torch.full(fill_value=False, size=(batch_size, num_letters))
    letter_mask = letter_mask.scatter(index=word_matrix_masked, dim=1, value=True)
    letter_mask[:, 0] = 0
    return letter_mask

In [19]:
from torch.distributions import Categorical

class Encoder(nn.Module):
    def __init__(self, letter_tokens, guess_tokens, emb_dim, hid_dim, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        
        self.letter_embedding = nn.Embedding(letter_tokens, emb_dim)
        self.guess_state_embedding = nn.Embedding(guess_tokens, emb_dim)

        self.rnn = nn.LSTM(emb_dim, hid_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, letter_seq, state_seq):
        letters_embedded = self.dropout(self.letter_embedding(letter_seq))
        states_embedded = self.dropout(self.guess_state_embedding(state_seq))

        outputs, (hidden, cell) = self.rnn(letters_embedded + states_embedded)
        
        #outputs = [src len, batch size, hid dim * n directions]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #outputs are always from the top hidden layer
        
        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, dropout):
        super().__init__()
        
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, dropout=dropout, batch_first=True)        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden, cell):
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
                
        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden, cell



class RNNAgent(nn.Module):
    def __init__(self, letter_tokens, guess_tokens, emb_dim, hid_dim, output_dim, game_voc_matrix, output_len, sos_token, dropout=0.2):
        super().__init__()
        
        self.encoder = Encoder(letter_tokens, guess_tokens, emb_dim, hid_dim, dropout)
        self.decoder = Decoder(output_dim, emb_dim, hid_dim, dropout)

        modules = [nn.Linear(hid_dim, hid_dim), nn.ReLU(), nn.Linear(hid_dim, 1)]
        self.V_head = nn.Sequential(*modules)
        
        self.letter_tokens = letter_tokens
        self.game_voc_matrix = game_voc_matrix
        self.output_len = output_len
        self.sos_token = sos_token
    
    def forward(self, letter_seq, state_seq):
        """
            inputs:
                letter_seq: (batch_size x num_sequences x sequence_length)
                state_seq: (batch_size x num_sequences x sequence_length)
                
            outputs:
                
        """
        # tensor to store decoder outputs
        batch_size = letter_seq.shape[0]
        logits = torch.zeros(batch_size, self.output_len + 1, self.letter_tokens)

        hidden, cell = self.encoder(letter_seq, state_seq)

        # compute V
        values = self.V_head(hidden.squeeze())

        # first input to the decoder is the <sos> tokens
        input = torch.full(size=(batch_size,), fill_value=self.sos_token)
        
        letter_mask = torch.full(size=(batch_size, self.letter_tokens), fill_value=True)
        word_mask = torch.full(size=(batch_size, self.game_voc_matrix.shape[0]), fill_value=True)

        # logits: (seq_length, batch_size, num_classes)
        
        actions = torch.zeros(size=(batch_size, self.output_len), dtype=torch.long)
        log_probs = torch.zeros(size=(batch_size,))
        for t in range(1, self.output_len + 1):

            # cur_logits: (batch_size, num_classes)
            # actions: (batch_size,)
            cur_logits, hidden, cell = self.decoder(input, hidden, cell)
            logits[:, t, :] = cur_logits

            probs = F.softmax(cur_logits, dim=-1)

            allowed_letters = get_allowed_letters(self.game_voc_matrix, word_mask, t-1)            
            probs = torch.where(allowed_letters, probs, torch.zeros_like(probs))
            # torch.where(<your_tensor> != 0, <tensor with zeroz>, <tensor with the value>)
            actions_t = Categorical(probs=probs).sample()
            
            word_mask = word_mask & (self.game_voc_matrix[:, t - 1].unsqueeze(0) == actions_t.unsqueeze(1))

            # keep which words are acceptable
            cur_log_probs = torch.log(probs[range(batch_size), actions_t].clip(min=1e-12)).squeeze()

            # letters_allowed_count = allowed_letters.sum(axis=-1)
            # log_probs[letters_allowed_count > 1] += cur_log_probs[letters_allowed_count > 1]
            log_probs += cur_log_probs
            
            actions[:, t-1] = actions_t
            input = actions_t

        return {
            "actions": actions.cpu().numpy(),
            # "logits": logits,
            "log_probs": log_probs,
            "values": values,
        }
    
    def act(self, inputs):
        '''
        input:
            inputs - numpy array, (batch_size x sequences x sequence_length)
        output: dict containing keys ['actions', 'logits', 'log_probs', 'values']:
            'actions' - selected actions, numpy, (batch_size, sequence_length)
            'log_probs' - log probs of selected actions, tensor, (batch_size)
            'values' - critic estimations, tensor, (batch_size)
        '''
        inputs = torch.LongTensor(inputs)
        letter_tokens, state_tokens = inputs[:, 0, :], inputs[:, 1, :]
        outputs = self(letter_tokens, state_tokens)
        return outputs

In [20]:
# test get_allowed_letters
from wrappers import SequenceWrapper, ReshapeWrapper
from wrappers import nature_dqn_env

num_letters = 29

env = WordleEnv()
env = SequenceWrapper(env, sos_token=1)
env = ReshapeWrapper(env)
# env = nature_dqn_env(nenvs=2)

word_mask = torch.tensor([[True, False, False], [False, True, False]])

present_letters = get_allowed_letters(torch.from_numpy(env.game_voc_matrix), word_mask, 3).to(torch.long)
present_letters

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
         0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0]])

In [21]:
env.tokenizer

<tokenizer.Tokenizer at 0x7fda5009ec10>

In [22]:
from tokenizer import Tokenizer

tokenizer = Tokenizer()
game_voc_matrix = torch.FloatTensor(env.game_voc_matrix)
agent = RNNAgent(len(tokenizer.index2letter), len(tokenizer.index2guess_state), 64, 32, len(tokenizer.index2letter), output_len=5, sos_token=1, game_voc_matrix=game_voc_matrix)

In [23]:
obs = env.reset()

letter_tokens = torch.LongTensor(obs[:, 0, :])
state_tokens = torch.LongTensor(obs[:, 1, :])

agent_output = agent(letter_tokens, state_tokens).values()
agent_output

dict_values([array([[18, 14,  3, 22,  7]]), tensor([-16.6782], grad_fn=<AddBackward0>), tensor([-0.0216], grad_fn=<AddBackward0>)])

In [24]:
from runners import EnvRunner

runner = EnvRunner(env, agent, nsteps=6)

trajectory = runner.get_next()
print(f"Trajectory keys: {trajectory.keys()}")
print(f"Trajectory rewards: {trajectory['rewards']}")

Trajectory keys: dict_keys(['actions', 'log_probs', 'values', 'observations', 'rewards', 'dones'])
Trajectory rewards: [array([0.]), array([0.]), array([0.]), array([10.]), array([10.]), array([0.4])]


In [25]:
nenvs = 1

# Sanity checks
# assert 'logits' in trajectory, "Not found: policy didn't provide logits"
assert 'log_probs' in trajectory, "Not found: policy didn't provide log_probs of selected actions"
assert 'values' in trajectory, "Not found: policy didn't provide critic estimations"
# assert trajectory['logits'][0].shape == (nenvs, n_actions), "logits wrong shape"
assert trajectory['log_probs'][0].shape == (nenvs,), "log_probs wrong shape"
assert trajectory['values'][0].shape == (nenvs,), "values wrong shape"

for key in trajectory.keys():
    assert len(trajectory[key]) == 6, \
    f"something went wrong: 6 steps should have been done, got trajectory of length {len(trajectory[key])} for '{key}'"


In [26]:
class ComputeValueTargets:
    def __init__(self, policy, gamma=0.99):
        self.policy = policy
        self.gamma = gamma

    def __call__(self, trajectory, latest_observation):
        '''
        This method should modify trajectory inplace by adding 
        an item with key 'value_targets' to it
        
        input:
            trajectory - dict from runner
            latest_observation - last state, numpy, (num_envs x channels x width x height)
        '''
        T = len(trajectory['rewards'])
        targets = [None] * T
        R = self.policy.act(latest_observation)['values']
        for t in range(T - 1, -1, -1):
            rewards = torch.FloatTensor(trajectory['rewards'][t]).to(DEVICE)
            dones = torch.LongTensor(trajectory['dones'][t]).to(DEVICE)
            R = rewards + (1 - dones) * self.gamma * R
            targets[t] = R
        trajectory['value_targets'] = targets


In [27]:
class MergeTimeBatch:
    """ Merges first two axes typically representing time and env batch. """
    def __call__(self, trajectory, latest_observation):
        trajectory['log_probs'] = torch.cat(trajectory['log_probs'], dim=0)
        trajectory['values'] = torch.cat(trajectory['values'], dim=0)        
        trajectory['value_targets'] = torch.cat(trajectory['value_targets'], dim=0)

In [28]:
runner = EnvRunner(env, agent, nsteps=6, transforms=[ComputeValueTargets(agent),
                                                      MergeTimeBatch()])
trajectory = runner.get_next()

In [29]:
from collections import defaultdict
from torch.nn.utils import clip_grad_norm_

class A2C:
    def __init__(self, policy, optimizer, value_loss_coef=0.25, entropy_coef=0.01, max_grad_norm=0.5):
        self.policy = policy
        self.optimizer = optimizer
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
    
    def loss(self, trajectory, write):
        # compute all losses
        # do not forget to use weights for critic loss and entropy loss

        targets = trajectory['value_targets'].to(DEVICE).detach()
        values = trajectory['values'].to(DEVICE)
        log_probs = trajectory['log_probs'].to(DEVICE)
        value_loss = (targets - values).pow(2).mean()
        
        # TODO: recompute
        entropy_loss = (log_probs * torch.exp(log_probs)).mean()
        
        advantage = (targets - values).detach()
        policy_loss = -(log_probs * advantage).mean()
        
        
        # log all losses
        write('losses', {
            'policy loss': policy_loss,
            'critic loss': value_loss,
            'entropy loss': entropy_loss
        })
        
        # additional logs
        write('critic/advantage', advantage.mean())
        write('critic/values', {
            'value predictions': values.mean(),
            'value targets': targets.mean(),
        })
        
        # return scalar loss
        return policy_loss + self.value_loss_coef * value_loss + self.entropy_coef * entropy_loss               

    def train(self, runner):
        # collect trajectory using runner
        # compute loss and perform one step of gradient optimization
        # do not forget to clip gradients
        
        trajectory = runner.get_next()
        
        self.optimizer.zero_grad()
        loss = self.loss(trajectory, runner.write)
        loss.backward()
        grad_norm = clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.optimizer.step()
        
        runner.write('gradient norm', grad_norm)


In [30]:
from wrappers import SequenceWrapper, ReshapeWrapper, TensorboardSummaries
from wrappers import nature_dqn_env

env = WordleEnv()
env = SequenceWrapper(env, sos_token=1)
env = ReshapeWrapper(env)
env = TensorboardSummaries(env, prefix='wordle')

[DEBUG/MainProcess] created semlock with handle 129
[DEBUG/MainProcess] created semlock with handle 131
[DEBUG/MainProcess] created semlock with handle 130


In [31]:
from torch.optim import RMSprop

nenvs = 1
nsteps = 10
total_steps = 10 ** 7

# env = nature_dqn_env("SpaceInvadersNoFrameskip-v4", nenvs=nenvs)
# n_actions = env.action_space.spaces[0].n
obs = env.reset()

# model = Model(obs.shape[1:], n_actions).to(DEVICE)
policy = RNNAgent(
    len(tokenizer.index2letter), 
    len(tokenizer.index2guess_state), 
    64, 32, 
    len(tokenizer.index2letter), 
    output_len=5, 
    sos_token=1, 
    game_voc_matrix=game_voc_matrix
)

runner = EnvRunner(env, policy, nsteps=nsteps, transforms=[ComputeValueTargets(policy),
                                                      MergeTimeBatch()])
optimizer = RMSprop(policy.parameters(), 7e-4)
a2c = A2C(policy, optimizer, max_grad_norm=1.0)

In [None]:
from tqdm import trange

obs = env.reset()
for step in trange(0, total_steps + 1, nenvs * nsteps):
    a2c.train(runner)

  0%|                                                                                                                                           | 0/1000001 [00:00<?, ?it/s][DEBUG/MainProcess] Queue._start_thread()
[DEBUG/MainProcess] doing self._thread.start()
[DEBUG/MainProcess] starting thread to feed data to pipe
[DEBUG/MainProcess] ... done self._thread.start()
[DEBUG/MainProcess] created semlock with handle 137
[DEBUG/MainProcess] created semlock with handle 138
[DEBUG/MainProcess] created semlock with handle 139
[DEBUG/MainProcess] Queue._start_thread()
[DEBUG/MainProcess] doing self._thread.start()
[DEBUG/MainProcess] starting thread to feed data to pipe
[DEBUG/MainProcess] ... done self._thread.start()
[DEBUG/MainProcess] created semlock with handle 145
[DEBUG/MainProcess] created semlock with handle 146
[DEBUG/MainProcess] created semlock with handle 147
[DEBUG/MainProcess] Queue._start_thread()
[DEBUG/MainProcess] doing self._thread.start()
[DEBUG/MainProcess] starting thread