In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from wordle_env import WordleEnv
import torch
import torch.nn as nn
import torch.nn.functional as F




In [2]:
from model import RNNAgent, get_allowed_letters

In [3]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f9a11d67e80>

In [4]:
DEVICE = torch.device('cpu')

In [5]:
from wrappers import SequenceWrapper, ReshapeWrapper
from wrappers import nature_dqn_env

num_letters = 29
env = nature_dqn_env(nenvs=4)
from tokenizer import Tokenizer

tokenizer = Tokenizer()

game_voc_matrix = torch.FloatTensor(env.game_voc_matrix)
agent = RNNAgent(len(tokenizer.index2letter), len(tokenizer.index2guess_state), 32, 64, len(tokenizer.index2letter), output_len=5, sos_token=1, game_voc_matrix=game_voc_matrix)






[INFO/Process-2] child process calling self.run()
[INFO/Process-1] child process calling self.run()
[INFO/Process-2] child process calling self.run()
[INFO/Process-1] child process calling self.run()
[INFO/Process-3] child process calling self.run()
[INFO/Process-3] child process calling self.run()
[INFO/Process-4] child process calling self.run()
[INFO/Process-4] child process calling self.run()
[DEBUG/MainProcess] created semlock with handle 92
[DEBUG/MainProcess] created semlock with handle 94
[DEBUG/MainProcess] created semlock with handle 96


In [6]:
from runners import EnvRunner

runner = EnvRunner(env, agent, nsteps=6)

trajectory = runner.get_next()
print(f"Trajectory keys: {trajectory.keys()}")
print(f"Trajectory rewards: {trajectory['rewards']}")
print(f"Trajectory values: {trajectory['values']}")

[DEBUG/MainProcess] Queue._start_thread()
[DEBUG/MainProcess] doing self._thread.start()
[DEBUG/MainProcess] starting thread to feed data to pipe
[DEBUG/MainProcess] ... done self._thread.start()


Trajectory keys: dict_keys(['actions', 'log_probs', 'values', 'observations', 'rewards', 'dones'])
Trajectory rewards: [array([0. , 0.4, 0. , 0.4]), array([0.4, 0. , 0. , 0.2]), array([0.4, 0.4, 0. , 0.2]), array([0. , 0.4, 0. , 0. ]), array([-0.4, -0.4, -0.2,  0.4]), array([ 0.8, -0.4, -0.4, -0.2])]
Trajectory values: [tensor([-0.0755, -0.0755, -0.0755, -0.0755], grad_fn=<SqueezeBackward0>), tensor([-0.0755, -0.0755, -0.0755, -0.0755], grad_fn=<SqueezeBackward0>), tensor([-0.0755, -0.0755, -0.0755, -0.0755], grad_fn=<SqueezeBackward0>), tensor([-0.0754, -0.0755, -0.0755, -0.0754], grad_fn=<SqueezeBackward0>), tensor([-0.0744, -0.0751, -0.0751, -0.0752], grad_fn=<SqueezeBackward0>), tensor([-0.0741, -0.0758, -0.0769, -0.0761], grad_fn=<SqueezeBackward0>)]


In [7]:
nenvs = 4

# Sanity checks
# assert 'logits' in trajectory, "Not found: policy didn't provide logits"
assert 'log_probs' in trajectory, "Not found: policy didn't provide log_probs of selected actions"
assert 'values' in trajectory, "Not found: policy didn't provide critic estimations"
# assert trajectory['logits'][0].shape == (nenvs, n_actions), "logits wrong shape"
assert trajectory['log_probs'][0].shape == (nenvs,), "log_probs wrong shape"
assert trajectory['values'][0].shape == (nenvs,), "values wrong shape"

for key in trajectory.keys():
    assert len(trajectory[key]) == 6, \
    f"something went wrong: 6 steps should have been done, got trajectory of length {len(trajectory[key])} for '{key}'"


In [8]:
trajectory['values'][0].shape

torch.Size([4])

In [9]:
class ComputeValueTargets:
    def __init__(self, policy, gamma=0.75):
        self.policy = policy
        self.gamma = gamma

    def __call__(self, trajectory, latest_observation):
        '''
        This method should modify trajectory inplace by adding 
        an item with key 'value_targets' to it
        
        input:
            trajectory - dict from runner
            latest_observation - last state, numpy, (num_envs x channels x width x height)
        '''
        T = len(trajectory['rewards'])
        targets = [None] * T
        R = self.policy.act(latest_observation)['values']
        for t in range(T - 1, -1, -1):
            rewards = torch.FloatTensor(trajectory['rewards'][t]).to(DEVICE)
            dones = torch.LongTensor(trajectory['dones'][t]).to(DEVICE)
            R = rewards + (1 - dones) * self.gamma * R
            targets[t] = R
        trajectory['value_targets'] = targets


In [10]:
class MergeTimeBatch:
    """ Merges first two axes typically representing time and env batch. """
    def __call__(self, trajectory, latest_observation):
        trajectory['log_probs'] = torch.cat(trajectory['log_probs'], dim=0)
        trajectory['values'] = torch.cat(trajectory['values'], dim=0)        
        trajectory['value_targets'] = torch.cat(trajectory['value_targets'], dim=0)

In [11]:
runner = EnvRunner(env, agent, nsteps=6, transforms=[ComputeValueTargets(agent, gamma=0.9),
                                                      MergeTimeBatch()])
trajectory = runner.get_next()

In [12]:
from collections import defaultdict
from torch.nn.utils import clip_grad_norm_

class A2C:
    def __init__(self, policy, optimizer, value_loss_coef=0.25, entropy_coef=0.01, max_grad_norm=0.5):
        self.policy = policy
        self.optimizer = optimizer
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
    
    def loss(self, trajectory, write):
        # compute all losses
        # do not forget to use weights for critic loss and entropy loss

        targets = trajectory['value_targets'].to(DEVICE).detach()
        values = trajectory['values'].to(DEVICE)
        log_probs = trajectory['log_probs'].to(DEVICE)
        value_loss = (targets - values).pow(2).mean()
        
        # TODO: recompute
        entropy_loss = 0.0 # (log_probs * torch.exp(log_probs)).mean()
        
        advantage = (targets - values).detach()
        policy_loss = -(log_probs * advantage).mean()
        
        
        # log all losses
        write('losses', {
            'policy loss': policy_loss,
            'critic loss': value_loss,
            'entropy loss': entropy_loss
        })
        
        # additional logs
        write('critic/advantage', advantage.mean())
        write('critic/values', {
            'value predictions': values.mean(),
            'value targets': targets.mean(),
        })
        
        # return scalar loss
        return policy_loss + self.value_loss_coef * value_loss + self.entropy_coef * entropy_loss               

    def train(self, runner):
        # collect trajectory using runner
        # compute loss and perform one step of gradient optimization
        # do not forget to clip gradients
        
        trajectory = runner.get_next()
        
        self.optimizer.zero_grad()
        loss = self.loss(trajectory, runner.write)
        loss.backward()
        grad_norm = clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.optimizer.step()
        
        runner.write('gradient norm', grad_norm)


In [13]:
env.close()

[INFO/Process-2] process shutting down
[INFO/Process-2] process shutting down
[INFO/Process-4] process shutting down
[INFO/Process-4] process shutting down
[DEBUG/Process-2] running all "atexit" finalizers with priority >= 0
[INFO/Process-1] process shutting down
[DEBUG/Process-2] running all "atexit" finalizers with priority >= 0
[INFO/Process-1] process shutting down
[DEBUG/Process-4] running all "atexit" finalizers with priority >= 0
[DEBUG/Process-4] running all "atexit" finalizers with priority >= 0
[DEBUG/Process-2] running the remaining "atexit" finalizers
[DEBUG/Process-2] running the remaining "atexit" finalizers
[DEBUG/Process-1] running all "atexit" finalizers with priority >= 0
[DEBUG/Process-4] running the remaining "atexit" finalizers
[DEBUG/Process-1] running all "atexit" finalizers with priority >= 0
[DEBUG/Process-4] running the remaining "atexit" finalizers
[DEBUG/Process-1] running the remaining "atexit" finalizers
[DEBUG/Process-1] running the remaining "atexit" fin

In [14]:
from wordle_env import WordleEnv
from wrappers import SequenceWrapper, ReshapeWrapper, TensorboardSummaries
from wrappers import nature_dqn_env

# env = WordleEnv()
# env = SequenceWrapper(env, sos_token=1)
# env = ReshapeWrapper(env)
# env = TensorboardSummaries(env, prefix='wordle')

env = nature_dqn_env(nenvs=4)





[INFO/Process-6] child process calling self.run()
[INFO/Process-6] child process calling self.run()
[INFO/Process-5] child process calling self.run()
[INFO/Process-5] child process calling self.run()
[INFO/Process-8] child process calling self.run()
[INFO/Process-8] child process calling self.run()
[INFO/Process-7] child process calling self.run()
[INFO/Process-7] child process calling self.run()
[DEBUG/MainProcess] created semlock with handle 105
[DEBUG/MainProcess] created semlock with handle 107
[DEBUG/MainProcess] created semlock with handle 114


In [15]:
from torch.optim import RMSprop

nenvs = 4
nsteps = 10
total_steps = 10 ** 6

# env = nature_dqn_env("SpaceInvadersNoFrameskip-v4", nenvs=nenvs)
# n_actions = env.action_space.spaces[0].n
obs = env.reset()

# model = Model(obs.shape[1:], n_actions).to(DEVICE)
policy = RNNAgent(
    len(tokenizer.index2letter), 
    len(tokenizer.index2guess_state), 
    32, 128, 
    len(tokenizer.index2letter), 
    output_len=5, 
    sos_token=1, 
    game_voc_matrix=game_voc_matrix
)

runner = EnvRunner(env, policy, nsteps=nsteps, transforms=[ComputeValueTargets(policy),
                                                      MergeTimeBatch()])
optimizer = RMSprop(policy.parameters(), 7e-4)
a2c = A2C(policy, optimizer, max_grad_norm=1.0)

[Level 5/MainProcess] finalizer calling <function close_fds at 0x7f9a12b9ed30> with args [90, 95] and kwargs {}
[Level 5/MainProcess] finalizer calling <function close_fds at 0x7f9a12b9ed30> with args [88, 93] and kwargs {}
[Level 5/MainProcess] finalizer calling <function close_fds at 0x7f9a12b9ed30> with args [87, 91] and kwargs {}
[Level 5/MainProcess] finalizer calling <function close_fds at 0x7f9a12b9ed30> with args [85, 89] and kwargs {}


In [16]:
from tqdm import trange

obs = env.reset()
for step in trange(0, total_steps + 1, nenvs * nsteps):
    if step % 1000 == 0 and step > 0:
        torch.save(a2c.policy.state_dict(), f"model_weights/step_{step}")

    a2c.train(runner)

[DEBUG/MainProcess] created semlock with handle 77
  0%|                                                                           | 0/25001 [00:00<?, ?it/s][DEBUG/MainProcess] Queue._start_thread()
[DEBUG/MainProcess] doing self._thread.start()
[DEBUG/MainProcess] starting thread to feed data to pipe
[DEBUG/MainProcess] ... done self._thread.start()
[DEBUG/MainProcess] created semlock with handle 87
[DEBUG/MainProcess] created semlock with handle 88
[DEBUG/MainProcess] created semlock with handle 89
[DEBUG/MainProcess] Queue._start_thread()
[DEBUG/MainProcess] doing self._thread.start()
[DEBUG/MainProcess] starting thread to feed data to pipe
[DEBUG/MainProcess] ... done self._thread.start()
[DEBUG/MainProcess] created semlock with handle 119
[DEBUG/MainProcess] created semlock with handle 120
[DEBUG/MainProcess] created semlock with handle 121
[DEBUG/MainProcess] Queue._start_thread()
[DEBUG/MainProcess] doing self._thread.start()
[DEBUG/MainProcess] starting thread to feed data to p

KeyboardInterrupt: 

In [None]:
torch.save(policy.state_dict(), 'model_weights/better_policy.pth')

In [None]:
obs = env.reset()

In [17]:
# test get_allowed_letters
from wrappers import SequenceWrapper, ReshapeWrapper
from wrappers import nature_dqn_env

num_letters = 29

env = WordleEnv()
env = SequenceWrapper(env, sos_token=1)
env = ReshapeWrapper(env)

In [18]:
def transform2word(word_vector):
    letter_list = list(map(lambda x: env.tokenizer.index2letter[x], word_vector))
    return ''.join(letter_list)

In [19]:
policy = RNNAgent(
    len(tokenizer.index2letter), 
    len(tokenizer.index2guess_state), 
    64, 32, 
    len(tokenizer.index2letter), 
    output_len=5, 
    sos_token=1, 
    game_voc_matrix=game_voc_matrix
)

In [29]:
tokenizer.guess_state2index

{'<PAD>': 0,
 '<SOS>': 1,
 '<EOS>': 2,
 '<RIGHT>': 3,
 '<CONTAINED>': 4,
 '<MISS>': 5}

In [30]:
obs = env.reset()

env_word = env.word
print(f"Real word: {transform2word(env.word)}")

for _ in range(6):
    action = a2c.policy.act(obs)['actions'].squeeze()
    obs, rew, done, info = env.step(action)
    print((obs[:, 1, :] == 3).sum())
    print(obs)
    print(f"guess: {transform2word(action)}, reward: {rew[0]}")

Real word: snail
2
[[[ 1 21 22  3 21 10  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
    0  0  0  0  0  0  0  0  0  0  0  0  0]
  [ 1  3  5  3  5  5  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
    0  0  0  0  0  0  0  0  0  0  0  0  0]]]
guess: stash, reward: 0.8
3
[[[ 1 21 22  3 21 10  1 21 14 17 17 18  1  0  0  0  0  0  0  0  0  0  0
    0  0  0  0  0  0  0  0  0  0  0  0  0]
  [ 1  3  5  3  5  5  1  3  5  5  5  4  1  0  0  0  0  0  0  0  0  0  0
    0  0  0  0  0  0  0  0  0  0  0  0  0]]]
guess: sloop, reward: 0.4
5
[[[ 1 21 22  3 21 10  1 21 14 17 17 18  1 21 16 11 20 22  1  0  0  0  0
    0  0  0  0  0  0  0  0  0  0  0  0  0]
  [ 1  3  5  3  5  5  1  3  5  5  5  4  1  3  3  5  4  5  1  0  0  0  0
    0  0  0  0  0  0  0  0  0  0  0  0  0]]]
guess: snirt, reward: 0.8
7
[[[ 1 21 22  3 21 10  1 21 14 17 17 18  1 21 16 11 20 22  1 21 16 11  9
   21  1  0  0  0  0  0  0  0  0  0  0  0]
  [ 1  3  5  3  5  5  1  3  5  5  5  4  1  3  3  5  4  5  1  3  3  5  4
    5  1  0  0

In [25]:
obs

array([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])

In [21]:
obs = env.reset()
env_word = env.word
print(f"Real word: {transform2word(env.word)}")

Real word: inlet


In [None]:
action = np.array([6, 6, v, 8, 8])
obs, rew, done, info = env.step(action)
print(f"guess: {transform2word(action)}, reward: {rew[0]}")

In [None]:
obs

In [None]:
action

In [None]:
action == env_word

In [None]:
obs = env.reset()
done = False

print(f"True word: {transform2word(env.word)}")

print("guesses:")
print("--------")

while not done:
    action = agent.act(obs)['actions'].squeeze()
    obs, rew, done, info = env.step(action)
    print(f"{transform2word(action)} (reward = {rew})")

In [None]:
GAME_VOCABULARY