In [22]:
import os

# Initial packages
import numpy as np
import random
import string
from typing import Optional
from tqdm import tqdm
from einops import reduce

# ML packages for neural-network
import torch as T
import torch.optim as optim
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch import einsum
import gymnasium as gym

# T.autograd.set_detect_anomaly(True)

In [23]:
def load_dictionary(file_path, word_length=9, seed=2024):
    """Load a newline-separated word list, keep words of one length, and shuffle them reproducibly.

    Args:
        file_path: path to a text file with one word per line.
        word_length: only words with exactly this many characters (after stripping
            surrounding whitespace) are kept. Defaults to 9.
        seed: seed for the shuffle, so repeated calls return the same order. Defaults to 2024.

    Returns:
        list[str]: the kept words in shuffled order.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # Remove newline characters and surrounding whitespace
        words = [line.strip() for line in file]

    trimmed_words_list = [word for word in words if len(word) == word_length]

    # A private RNG gives the same order as seeding the module-level generator would,
    # without silently resetting the global `random` state for the rest of the notebook.
    random.Random(seed).shuffle(trimmed_words_list)

    return trimmed_words_list

In [24]:
# Masking Function provided through: https://boring-guy.sh/posts/masking-rl/

class CategoricalMasked(Categorical):
    """Categorical distribution in which masked-out actions get (numerically) zero probability.

    Args:
        logits: unnormalised scores of shape ``(batch, n_actions)`` (unpacked as exactly 2 dims).
        mask: optional tensor of the same shape; ``True`` marks an allowed action. It is
            moved to ``logits.device`` and converted to ``bool``, so a mask built on CPU
            works with logits on GPU. ``None`` means every action is allowed.
    """

    def __init__(self, logits: T.Tensor, mask: Optional[T.Tensor] = None):
        self.batch, self.nb_action = logits.size()
        if mask is None:
            self.mask = None
            super().__init__(logits=logits)
            return

        self.mask = mask.to(device=logits.device, dtype=T.bool)
        # The most negative finite value (rather than -inf) keeps softmax free of NaNs.
        self.mask_value = T.tensor(
            T.finfo(logits.dtype).min, dtype=logits.dtype, device=logits.device
        )
        super().__init__(logits=T.where(self.mask, logits, self.mask_value))

    def entropy(self):
        """Entropy per batch row, summed over allowed actions only."""
        if self.mask is None:
            return super().entropy()
        p_log_p = self.logits * self.probs
        # Masked entries contribute nothing to the entropy
        p_log_p = T.where(self.mask, p_log_p, T.zeros_like(p_log_p))
        return -p_log_p.sum(dim=-1)

In [25]:
class HangmanEnv(gym.Env):
    """Gymnasium environment for one game of Hangman over a fixed dictionary.

    Observation: a flattened ``31 x 27`` int8 grid (837 values).
      * Rows ``0 .. len(word)-1`` describe word positions. Column ``26`` is ``-1`` while a
        position is hidden. When a letter is revealed, its letter column becomes ``1`` and
        column ``26`` is cleared to ``0``.
      * Row ``-1`` records guesses. Letter column ``k`` is set to ``-1`` after letter ``k``
        is guessed, and column ``26`` of that row is set to ``-1`` on reset.
    Actions: ``Discrete(26)``; ``0 .. 25`` map to ``'a' .. 'z'``.

    Words are drawn either from ``valid_indices`` (dictionary words not yet played) or from
    ``incorrect_indices`` (a replay list of lost words that the caller maintains).
    """

    # Rows 0..29 of the 31-row grid hold word positions; the last row holds guesses.
    MAX_WORD_LENGTH = 30

    def __init__(self, dictionary, total_lives):
        super().__init__()

        # Lives for the game
        self.total_lives = total_lives

        # The dictionary will not change, so save it here
        self.dictionary = dictionary
        self.valid_indices = list(range(len(self.dictionary)))
        self.incorrect_indices = []

        # Action space involves choosing ['a','b',...,'y','z'] --> [0,1,...,24,25]
        self.action_space = gym.spaces.Discrete(26)

        # Observation: the one-hot word grid plus the guessed-letter row (see class docstring)
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(31 * 27,), dtype=np.int8)

    def step(self, action):
        """Guess letter ``action`` and return ``(obs, reward, done, truncated, info)``.

        Rewards:
          * a hit earns ``hits / word_length``;
          * a miss earns 0 and costs one life;
          * solving the word earns ``lives_remaining * 10``.

        Raises ``ValueError`` (from ``list.remove``) if the letter was already guessed.
        """
        # Turn the action into a letter
        current_guess = chr(action + ord('a'))

        self.guessed_letters.append(current_guess)
        self.available_letters.remove(current_guess)

        correct_indices = np.flatnonzero(self.guess_word == current_guess)
        if correct_indices.size > 0:
            self.current_word_state[correct_indices] = current_guess
            self.reward = len(correct_indices) / len(self.guess_word)
        else:
            # Not in the word, then remove a life
            self.lives_remaining -= 1
            self.reward = 0

        if np.array_equal(self.current_word_state, self.guess_word):
            self.status = 1
            self.reward = self.lives_remaining * 10
            self.done = True
        elif self.lives_remaining <= 0:
            self.status = 0
            self.done = True

        # Mark the letter as guessed and reveal every position where it occurs
        self.one_hot_state[-1, action] = -1
        self.one_hot_state[correct_indices, action] = 1
        self.one_hot_state[correct_indices, -1] = 0
        self.observation = self.one_hot_state.flatten()

        info = {}
        truncated = False
        return self.observation, self.reward, self.done, truncated, info

    def reset(self, seed=None, options=None):
        """Start a new game and return ``(observation, info)``.

        ``seed`` is forwarded to ``gym.Env.reset``. The word itself is still drawn with
        the global ``np.random`` generator.
        """
        super().reset(seed=seed)

        self.lives_remaining = self.total_lives
        self.available_letters = list(string.ascii_lowercase)
        self.guessed_letters = []
        self.done = False
        self.status = 0
        self.reward = 0

        # Draw from the unplayed words or the lost-word replay list. Never sample from an
        # empty pool: that would make np.random.choice raise once the sweep is exhausted.
        if self.valid_indices and self.incorrect_indices:
            pool = self.valid_indices if np.random.choice([0, 1]) == 0 else self.incorrect_indices
        elif self.incorrect_indices:
            pool = self.incorrect_indices
        else:
            if not self.valid_indices:
                # Every word has been played and solved: start a fresh sweep
                self.valid_indices = list(range(len(self.dictionary)))
            pool = self.valid_indices
        self.guess_index = np.random.choice(pool)

        word = self.dictionary[self.guess_index]
        if len(word) > self.MAX_WORD_LENGTH:
            raise ValueError(
                f'Word {word!r} has {len(word)} letters; the observation grid supports at most '
                f'{self.MAX_WORD_LENGTH}.'
            )
        self.guess_word = np.array(list(word))
        self.current_word_state = np.array(['_'] * len(self.guess_word))

        # All word positions start hidden
        self.one_hot_state = np.zeros([31, 27], dtype=np.int8)
        self.one_hot_state[:len(self.guess_word), -1] = -1
        self.one_hot_state[-1, -1] = -1
        self.observation = self.one_hot_state.flatten()

        info = {}
        return self.observation, info

    def render(self, action):
        """Print the guessed letter, the revealed word state and the remaining lives."""
        result = ' '.join([str(elem) for elem in self.current_word_state])

        print(f'Guessed {chr(action + ord("a"))}. The current word state is {result} and with {self.lives_remaining} lives remaining.')

In [26]:
# PPO classes originally defined here: https://github.com/philtabor/Youtube-Code-Repository/tree/master/ReinforcementLearning/PolicyGradient/PPO/torch

class PPOMemory:
    """On-policy rollout buffer for PPO.

    Holds one entry per environment step in parallel Python lists. They are returned as
    numpy arrays together with shuffled mini-batch index arrays.
    """

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.clear_memory()

    def generate_batches(self):
        """Return ``(states, actions, probs, vals, rewards, dones, batches)``.

        ``batches`` is a list of index arrays that partition ``range(len(states))`` in random
        order; each has ``batch_size`` entries except possibly the last.
        """
        n_samples = len(self.states)
        order = np.arange(n_samples, dtype=np.int64)
        np.random.shuffle(order)
        batches = [order[start:start + self.batch_size]
                   for start in range(0, n_samples, self.batch_size)]
        arrays = tuple(np.array(buffer) for buffer in (
            self.states, self.actions, self.probs, self.vals, self.rewards, self.dones))
        return (*arrays, batches)

    def store_memory(self, state, action, probs, vals, reward, done):
        """Append one transition to every buffer."""
        pairs = (
            (self.states, state),
            (self.actions, action),
            (self.probs, probs),
            (self.vals, vals),
            (self.rewards, reward),
            (self.dones, done),
        )
        for buffer, item in pairs:
            buffer.append(item)

    def clear_memory(self):
        """Discard every stored transition."""
        self.states, self.probs, self.actions = [], [], []
        self.rewards, self.dones, self.vals = [], [], []

class ActorNetwork(nn.Module):
    """Policy network: maps a batch of flat states to a categorical distribution over actions.

    Args:
        n_actions: number of discrete actions (width of the output layer).
        input_dims: observation shape tuple, unpacked into the first ``nn.Linear``; it must
            therefore contain exactly one element (e.g. ``(837,)``).
        alpha: Adam learning rate.
        fc1_dims, fc2_dims: hidden layer widths.
        chkpt_dir: directory used by ``save_checkpoint`` / ``load_checkpoint``.
    """

    def __init__(self, n_actions, input_dims, alpha,
            fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
        super().__init__()

        self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
        self.actor = nn.Sequential(
                nn.Linear(*input_dims, fc1_dims),
                nn.ReLU(),
                nn.Linear(fc1_dims, fc2_dims),
                nn.ReLU(),
                nn.Linear(fc2_dims, n_actions),
                nn.Softmax(dim=-1)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return a ``Categorical`` built from the softmax output for ``state``."""
        return Categorical(self.actor(state))

    def save_checkpoint(self):
        """Write the state dict to ``checkpoint_file``, creating its directory if needed."""
        # T.save does not create missing parent directories.
        ckpt_dir = os.path.dirname(self.checkpoint_file)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the state dict from ``checkpoint_file`` onto this network's device."""
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))

class CriticNetwork(nn.Module):
    """Value network: maps a batch of flat states to one scalar value estimate each.

    Args:
        input_dims: observation shape tuple, unpacked into the first ``nn.Linear``; it must
            therefore contain exactly one element (e.g. ``(837,)``).
        alpha: Adam learning rate.
        fc1_dims, fc2_dims: hidden layer widths.
        chkpt_dir: directory used by ``save_checkpoint`` / ``load_checkpoint``.
    """

    def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256,
            chkpt_dir='tmp/ppo'):
        super().__init__()

        self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
        self.critic = nn.Sequential(
                nn.Linear(*input_dims, fc1_dims),
                nn.ReLU(),
                nn.Linear(fc1_dims, fc2_dims),
                nn.ReLU(),
                nn.Linear(fc2_dims, 1)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return the value estimate(s) for ``state`` with a trailing dimension of size 1."""
        return self.critic(state)

    def save_checkpoint(self):
        """Write the state dict to ``checkpoint_file``, creating its directory if needed."""
        # T.save does not create missing parent directories.
        ckpt_dir = os.path.dirname(self.checkpoint_file)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the state dict from ``checkpoint_file`` onto this network's device."""
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))

class Agent:
    """PPO agent (actor + critic) with support for masking invalid actions.

    The action mask used in ``choose_action`` is recorded by the following ``remember``
    call. ``learn`` then re-applies the same mask, so the PPO probability ratio compares
    old and new log-probabilities of the *same* masked distribution.
    """

    def __init__(self, n_actions, input_dims, gamma=0.99, alpha=0.0003, gae_lambda=0.95,
            policy_clip=0.1, batch_size=64, n_epochs=10):
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda
        self.n_actions = n_actions

        self.actor = ActorNetwork(n_actions, input_dims, alpha)
        self.critic = CriticNetwork(input_dims, alpha)
        self.memory = PPOMemory(batch_size)

        # Boolean action masks, aligned one-to-one with the transitions in self.memory
        self._masks = []
        # Mask from the most recent choose_action call, consumed by the next remember()
        self._pending_mask = None

    def remember(self, state, action, probs, vals, reward, done):
        """Store a transition together with the action mask that was active when it was sampled."""
        self.memory.store_memory(state, action, probs, vals, reward, done)
        mask = self._pending_mask
        if mask is None:
            # No preceding choose_action: treat every action as allowed
            mask = np.ones(self.n_actions, dtype=bool)
        self._masks.append(mask)
        self._pending_mask = None

    def save_models(self):
        print('... saving models ...')
        self.actor.save_checkpoint()
        self.critic.save_checkpoint()

    def load_models(self):
        print('... loading models ...')
        self.actor.load_checkpoint()
        self.critic.load_checkpoint()

    def choose_action(self, observation, viable_actions):
        """Sample an action for ``observation``, never picking an entry flagged ``-1``.

        Args:
            observation: flat state vector.
            viable_actions: per-action flags; entries equal to ``-1`` are excluded.

        Returns:
            ``(action, log_prob, value)`` as Python scalars. ``log_prob`` is the log-probability
            of ``action`` under the (masked) policy.
        """
        device = self.actor.device
        state = T.as_tensor(np.asarray(observation), dtype=T.float32, device=device).unsqueeze(0)

        available = np.asarray(viable_actions) != -1
        self._pending_mask = available

        with T.no_grad():
            dist = self.actor(state)
            if not available.all():
                # Build the mask on the logits' device to avoid a CPU/GPU mismatch in T.where
                mask = T.as_tensor(available, device=dist.logits.device).unsqueeze(0)
                dist = CategoricalMasked(logits=dist.logits, mask=mask)

            action = dist.sample()
            log_prob = dist.log_prob(action).item()
            value = self.critic(state).item()

        return action.item(), log_prob, value

    def _compute_advantages(self, rewards, values, dones):
        """Generalised Advantage Estimation over the stored rollout, in O(n).

        The recursion is cut at every terminal step, so advantages never leak across
        episode boundaries. The final transition has no stored successor value: unless it
        is terminal, its advantage is left at 0.
        """
        n = len(rewards)
        advantages = np.zeros(n, dtype=np.float32)
        running = 0.0
        for t in reversed(range(n)):
            not_done = 0.0 if dones[t] else 1.0
            if t == n - 1 and not_done:
                continue
            next_value = values[t + 1] if t + 1 < n else 0.0
            delta = rewards[t] + self.gamma * next_value * not_done - values[t]
            running = delta + self.gamma * self.gae_lambda * not_done * running
            advantages[t] = running
        return advantages

    def learn(self):
        """Run ``n_epochs`` of clipped-PPO updates on the stored rollout, then clear it."""
        if len(self.memory.states) == 0:
            self._masks = []
            self.memory.clear_memory()
            return

        device = self.actor.device
        state_arr, action_arr, old_prob_arr, vals_arr, reward_arr, dones_arr, _ = \
                self.memory.generate_batches()

        # Rewards and stored values are fixed during the update, so GAE is computed once
        advantage = T.as_tensor(self._compute_advantages(reward_arr, vals_arr, dones_arr),
                                device=device)
        values = T.as_tensor(vals_arr, dtype=T.float32, device=device)
        returns_all = advantage + values
        states_all = T.as_tensor(state_arr, dtype=T.float32, device=device)
        actions_all = T.as_tensor(action_arr, dtype=T.int64, device=device)
        old_log_probs_all = T.as_tensor(old_prob_arr, dtype=T.float32, device=device)

        masks_all = None
        if len(self._masks) == len(state_arr):
            masks_all = T.as_tensor(np.array(self._masks, dtype=bool), device=device)
        # else: memory was filled without remember(); fall back to the unmasked policy

        for _ in range(self.n_epochs):
            batches = self.memory.generate_batches()[-1]
            for batch in batches:
                idx = T.as_tensor(batch, dtype=T.int64, device=device)
                states = states_all[idx]

                dist = self.actor(states)
                if masks_all is not None:
                    # Same masked distribution that produced the stored old log-probs
                    dist = CategoricalMasked(logits=dist.logits, mask=masks_all[idx])
                critic_value = self.critic(states).squeeze(-1)

                new_log_probs = dist.log_prob(actions_all[idx])
                prob_ratio = (new_log_probs - old_log_probs_all[idx]).exp()
                batch_adv = advantage[idx]
                weighted_probs = batch_adv * prob_ratio
                weighted_clipped_probs = T.clamp(prob_ratio, 1 - self.policy_clip,
                        1 + self.policy_clip) * batch_adv
                actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean()

                critic_loss = ((returns_all[idx] - critic_value) ** 2).mean()

                total_loss = actor_loss + 0.5 * critic_loss
                self.actor.optimizer.zero_grad()
                self.critic.optimizer.zero_grad()
                total_loss.backward()
                self.actor.optimizer.step()
                self.critic.optimizer.step()

        self._masks = []
        self.memory.clear_memory()

In [27]:
# Load the filtered, shuffled training word list from the local dictionary file
dictionary_path = 'words_250000_train.txt'
words_list = load_dictionary(dictionary_path)
n_loaded = len(words_list)
print(f'Loaded in dictionary with {n_loaded} words.')

Loaded in dictionary with 30906 words.


In [None]:
# Train the PPO agent: sweep the dictionary twice, replaying previously lost words.
env = HangmanEnv(dictionary=words_list, total_lives=6)
N = 2048            # environment steps between PPO updates
batch_size = 64
epochs = 4
alpha = 0.0003
resume_from_checkpoint = False   # set True to continue from saved weights
agent = Agent(n_actions=env.action_space.n, batch_size=batch_size, alpha=alpha, n_epochs=epochs, input_dims=env.observation_space.shape)
if resume_from_checkpoint:
    agent.actor.load_checkpoint()
    agent.critic.load_checkpoint()

n_games = len(env.dictionary) * 2
n_steps = 0
learn_iters = 0
wins_list = []

progress_bar = tqdm(range(n_games), desc='Sweeping through all words', leave=False)
for i in progress_bar:
    observation, info = env.reset()
    done = False
    score = 0
    while not done:
        # The last row of the state grid flags already-guessed letters with -1
        action, prob, val = agent.choose_action(observation, env.one_hot_state[-1, :-1])
        observation_, reward, done, truncated, info = env.step(action)
        score += reward
        # Store the transition before any update so it belongs to the rollout that produced it
        agent.remember(observation, action, prob, val, reward, done)
        n_steps += 1
        if n_steps % N == 0:
            agent.learn()
            learn_iters += 1
        observation = observation_

    # Update the word index lists: retire words from the sweep, replay lost ones until solved
    if env.guess_index in env.incorrect_indices:
        if env.status == 1:
            env.incorrect_indices.remove(env.guess_index)
    else:
        env.valid_indices.remove(env.guess_index)
        if env.status == 0:
            env.incorrect_indices.append(env.guess_index)

    # Tabulate wins and the win rate over the last 1000 games
    wins_list.append(env.status)
    win_rate = np.mean(wins_list[-1000:]) * 100

    progress_bar.set_postfix_str(s=f'Win Rate/1000: {win_rate:.2f}%.')

In [14]:
# Persist the trained weights. T.save does not create a missing checkpoint directory,
# so make sure it exists first.
for network in (agent.actor, agent.critic):
    os.makedirs(os.path.dirname(network.checkpoint_file) or '.', exist_ok=True)
    network.save_checkpoint()