In [1]:
import env
import random
import numpy as np

In [2]:
# Load the allowed-guess list and the word bank, one entry per line with
# surrounding whitespace (including the newline) stripped.
with open('valid-words.csv') as fh:
    valid_words = [line.strip() for line in fh]

with open('word-bank.csv') as fh:
    word_bank = [line.strip() for line in fh]

In [3]:
# Wordle environment built from the allowed-guess list and the word bank.
# It is driven through reset() and step(action) -> (observation, reward, done, info),
# as used in the cells below.
e = env.WordleEnv1(valid_words, word_bank)

In [20]:
# Start a new game; before any guess the observation matrix is all zeros (see output).
e.reset()

array([[0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       ...,
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]])

In [21]:
# Submit guess #432, an index into e.valid_words (here it decodes to 'antae',
# per info['guesses'] in the output). Returns (observation, reward, done, info).
e.step(432)

(array([[ 1,  1,  0, -1,  0],
        [ 1,  1,  0,  0,  0],
        [ 1,  1,  0,  0,  0],
        ...,
        [ 0,  0,  0,  0,  1],
        [ 0,  0,  0, -1,  0],
        [ 0,  0,  0,  0,  0]]),
 -1,
 False,
 {'attempt': 1,
  'guesses': [['a', 'n', 't', 'a', 'e'],
   ['', '', '', '', ''],
   ['', '', '', '', ''],
   ['', '', '', '', ''],
   ['', '', '', '', ''],
   ['', '', '', '', '']],
  'known_letters': ['', '', '', 'a', ''],
  'present_letters': ['n', 'a'],
  'absent_letters': ['e', 't']})

# Simple Argmax Agent

In [6]:
def simple_agent(observation):
    """Pick the next action (row index) greedily from an observation matrix.

    Every row of ``observation`` is scored by its row sum and the index of the
    highest-scoring row is returned. Before any feedback exists (the all-zero
    observation returned by ``e.reset()``) all rows tie, so a uniformly random
    row index is returned instead of always picking row 0.

    Args:
        observation: 2-D numpy array whose row index is used as the action
            passed to ``e.step`` (assumes rows line up with the env's word
            list — TODO confirm against env).

    Returns:
        An integer row index in ``range(observation.shape[0])``.
    """
    # Test for "no information yet" with any(), not sum() == 0: observations
    # contain -1 entries, so a non-empty board can still sum to zero.
    if not observation.any():
        # random.randint is inclusive of its upper bound, so randint(0, n)
        # could return n, one past the last row; randrange excludes it.
        return random.randrange(observation.shape[0])
    return observation.sum(axis=1).argmax()

In [7]:
%%time
rewards = []
attempts = []
for _ in range(100):
    done = False
    observation = e.reset()
    while not done:
        action = simple_agent(observation)
        # print(e.valid_words[action])
        observation, reward, done, info = e.step(action)
    rewards.append(reward)
    attempts.append(info["attempt"])
    # print(f'scored {reward} in {info["attempt"]} attempt(s)')
print(f'average score: {np.mean(rewards)}')
print(f'average attempts: {np.mean(attempts)}')

average score: 2.19
average attempts: 5.75
CPU times: user 2.38 s, sys: 7.33 ms, total: 2.39 s
Wall time: 2.38 s


# Stable-Baselines3 A2C Agent

In [8]:
from stable_baselines3 import A2C

In [12]:
%%time
model = A2C('MlpPolicy', e, verbose=0)
model.learn(total_timesteps=5000)

CPU times: user 1min 57s, sys: 1.34 s, total: 1min 59s
Wall time: 1min 9s


<stable_baselines3.a2c.a2c.A2C at 0x7fed40b40950>

In [19]:
%%time
rewards = []
attempts = []
for _ in range(1):
    done = False
    observation = e.reset()
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        print(e.valid_words[action])
        observation, reward, done, info = e.step(action)
    rewards.append(reward)
    attempts.append(info["attempt"])
    print(f'scored {reward} in {info["attempt"]} attempt(s)')
print(f'average score: {np.mean(rewards)}')
print(f'average attempts: {np.mean(attempts)}')

average score: 1.31
average attempts: 5.77
CPU times: user 2.34 s, sys: 10.9 ms, total: 2.35 s
Wall time: 2.35 s


# Manual Analysis

In [32]:
import string
import pandas as pd
from tabulate import tabulate

In [6]:
# Character matrices with one row per word and one column per letter slot
# (assumes every word has the same length): word bank, then allowed guesses.
wbm = np.array(list(map(list, word_bank)))
vwm = np.array(list(map(list, valid_words)))

## Letter Statistics

In [51]:
stats = []
for letter in string.ascii_lowercase:
    per_word = (wbm == letter).sum(axis=1)  # occurrences of `letter` in each bank word
    present = per_word > 0                  # bank words containing `letter` at least once
    stats.append({
        'letter': letter,
        'words_contain': present.sum(),
        'words_contain_percent': present.mean(),
        'freq': per_word.mean(),
        'freq_if_present': per_word[present].mean(),
        'max_freq': per_word.max(),
    })

In [91]:
print(tabulate(pd.DataFrame(stats).set_index('letter'), tablefmt="pipe", headers="keys"))

| letter   |   words_contain |   words_contain_percent |      freq |   freq_if_present |   max_freq |
|:---------|----------------:|------------------------:|----------:|------------------:|-----------:|
| a        |             909 |               0.392657  | 0.422894  |           1.07701 |          2 |
| b        |             267 |               0.115335  | 0.121382  |           1.05243 |          3 |
| c        |             448 |               0.193521  | 0.206048  |           1.06473 |          2 |
| d        |             370 |               0.159827  | 0.169762  |           1.06216 |          3 |
| e        |            1056 |               0.456156  | 0.532613  |           1.16761 |          3 |
| f        |             207 |               0.0894168 | 0.0993521 |           1.11111 |          3 |
| g        |             300 |               0.12959   | 0.134341  |           1.03667 |          2 |
| h        |             379 |               0.163715  | 0.168035  |           1.0

In [55]:
# Fraction of bank words that contain each letter at least once.
contain_rate = {
    letter: (wbm == letter).any(axis=1).mean()
    for letter in string.ascii_lowercase
}

In [75]:
# Score each allowed guess by summing contain_rate over its distinct letters,
# so a repeated letter only counts once.
scores = [sum(contain_rate[ch] for ch in set(word)) for word in valid_words]

In [76]:
df = pd.DataFrame(dict(word=valid_words, score=scores))

In [77]:
df = df.assign(in_bank=df['word'].isin(word_bank))  # flag guesses that are also bank words

In [93]:
# Top 20 word-bank words by letter-coverage score, as a markdown table.
print(tabulate(
    df.loc[df['in_bank']]
    .sort_values(by='score', ascending=False)
    .head(20)
    .set_index('word'),
    tablefmt="pipe",
    headers="keys",
))


| word   |   score | in_bank   |
|:-------|--------:|:----------|
| alter  | 1.7784  | True      |
| later  | 1.7784  | True      |
| alert  | 1.7784  | True      |
| irate  | 1.77797 | True      |
| arose  | 1.76803 | True      |
| stare  | 1.76544 | True      |
| raise  | 1.7568  | True      |
| arise  | 1.7568  | True      |
| learn  | 1.72786 | True      |
| renal  | 1.72786 | True      |
| saner  | 1.7149  | True      |
| snare  | 1.7149  | True      |
| cater  | 1.69201 | True      |
| trace  | 1.69201 | True      |
| react  | 1.69201 | True      |
| crate  | 1.69201 | True      |
| stale  | 1.6838  | True      |
| steal  | 1.6838  | True      |
| least  | 1.6838  | True      |
| slate  | 1.6838  | True      |


In [92]:
# Top 20 allowed guesses (bank or not) by letter-coverage score.
print(tabulate(
    df.sort_values(by='score', ascending=False)
    .head(20)
    .set_index('word'),
    tablefmt="pipe",
    headers="keys",
))


| word   |   score | in_bank   |
|:-------|--------:|:----------|
| oater  | 1.7892  | False     |
| roate  | 1.7892  | False     |
| orate  | 1.7892  | False     |
| realo  | 1.78099 | False     |
| taler  | 1.7784  | False     |
| artel  | 1.7784  | False     |
| ratel  | 1.7784  | False     |
| alert  | 1.7784  | True      |
| alter  | 1.7784  | True      |
| later  | 1.7784  | True      |
| terai  | 1.77797 | False     |
| irate  | 1.77797 | True      |
| retia  | 1.77797 | False     |
| raile  | 1.76976 | False     |
| ariel  | 1.76976 | False     |
| arose  | 1.76803 | True      |
| aeros  | 1.76803 | False     |
| soare  | 1.76803 | False     |
| taser  | 1.76544 | False     |
| strae  | 1.76544 | False     |


In [99]:
# Best bank words sharing no letter with "alter", i.e. candidate second guesses.
print(tabulate(
    df.loc[df['in_bank'] & ~df['word'].str.contains('[alter]', regex=True)]
    .sort_values(by='score', ascending=False)
    .head(20)
    .set_index('word'),
    tablefmt="pipe",
    headers="keys",
))


| word   |   score | in_bank   |
|:-------|--------:|:----------|
| sonic  | 1.26825 | True      |
| scion  | 1.26825 | True      |
| noisy  | 1.25486 | True      |
| disco  | 1.1905  | True      |
| bison  | 1.19006 | True      |
| sound  | 1.15248 | True      |
| synod  | 1.13521 | True      |
| shiny  | 1.12786 | True      |
| spiny  | 1.11361 | True      |
| suing  | 1.11102 | True      |
| using  | 1.11102 | True      |
| minus  | 1.11015 | True      |
| bonus  | 1.10799 | True      |
| doing  | 1.09719 | True      |
| dingo  | 1.09719 | True      |
| spicy  | 1.06955 | True      |
| music  | 1.06609 | True      |
| snowy  | 1.05918 | True      |
| bingo  | 1.0527  | True      |
| hound  | 1.04924 | True      |


In [100]:
# Best allowed guesses sharing no letter with "oater".
print(tabulate(
    df.loc[~df['word'].str.contains('[oater]', regex=True)]
    .sort_values(by='score', ascending=False)
    .head(20)
    .set_index('word'),
    tablefmt="pipe",
    headers="keys",
))


| word   |   score | in_bank   |
|:-------|--------:|:----------|
| lysin  | 1.24406 | False     |
| linds  | 1.22376 | False     |
| sulci  | 1.21728 | False     |
| sling  | 1.19352 | True      |
| lings  | 1.19352 | False     |
| limns  | 1.19266 | False     |
| hilus  | 1.18747 | False     |
| blins  | 1.17927 | False     |
| incus  | 1.17495 | False     |
| pilus  | 1.17322 | False     |
| pulis  | 1.17322 | False     |
| shily  | 1.17019 | False     |
| clips  | 1.16933 | False     |
| idyls  | 1.16631 | False     |
| unlid  | 1.15421 | False     |
| linch  | 1.15421 | False     |
| gusli  | 1.15335 | False     |
| iglus  | 1.15335 | False     |
| muils  | 1.15248 | False     |
| simul  | 1.15248 | False     |


## Word Scores

In [63]:
# Score assigned to a bank-word letter that matches the guess in the same slot.
# With the default of 1, an exact-position match scores exactly like a letter
# that is merely present somewhere in the guess (the exact-match assignment
# below is then a no-op); raise it to reward positional matches separately.
KNOWN_LETTER_WEIGHT = 1

stats = []
for word in vwm:
    # 1 wherever a bank-word letter occurs anywhere in the guess `word`
    score = np.isin(wbm, word).astype(int)
    # exact-position matches: the guess row is broadcast against every bank word
    score[wbm == word] = KNOWN_LETTER_WEIGHT
    stats.append({
        'avg_score': score.sum(axis=1).mean(),
    })

In [64]:
df = (
    pd.DataFrame(stats)
    .assign(word=valid_words)
    .assign(in_bank=lambda frame: frame['word'].isin(word_bank))
)

In [65]:
# Top 20 allowed guesses by average per-bank-word match score.
print(tabulate(
    df.sort_values(by='avg_score', ascending=False)
    .head(20)
    .set_index('word'),
    tablefmt="pipe",
    headers="keys",
))

| word   |   avg_score | in_bank   |
|:-------|------------:|:----------|
| roate  |     1.98445 | False     |
| oater  |     1.98445 | False     |
| orate  |     1.98445 | False     |
| realo  |     1.98013 | False     |
| artel  |     1.96933 | False     |
| later  |     1.96933 | True      |
| taler  |     1.96933 | False     |
| ratel  |     1.96933 | False     |
| alter  |     1.96933 | True      |
| alert  |     1.96933 | True      |
| soare  |     1.95853 | False     |
| aeros  |     1.95853 | False     |
| arose  |     1.95853 | True      |
| irate  |     1.9486  | True      |
| retia  |     1.9486  | False     |
| terai  |     1.9486  | False     |
| reast  |     1.94773 | False     |
| stare  |     1.94773 | True      |
| strae  |     1.94773 | False     |
| teras  |     1.94773 | False     |


In [66]:
# Same ranking restricted to words that are also in the word bank.
print(tabulate(
    df.loc[df["in_bank"]]
    .sort_values(by='avg_score', ascending=False)
    .head(20)
    .set_index('word'),
    tablefmt="pipe",
    headers="keys",
))

| word   |   avg_score | in_bank   |
|:-------|------------:|:----------|
| alter  |     1.96933 | True      |
| alert  |     1.96933 | True      |
| later  |     1.96933 | True      |
| arose  |     1.95853 | True      |
| irate  |     1.9486  | True      |
| stare  |     1.94773 | True      |
| raise  |     1.92268 | True      |
| arise  |     1.92268 | True      |
| learn  |     1.90281 | True      |
| renal  |     1.90281 | True      |
| snare  |     1.88121 | True      |
| saner  |     1.88121 | True      |
| least  |     1.86998 | True      |
| stale  |     1.86998 | True      |
| slate  |     1.86998 | True      |
| steal  |     1.86998 | True      |
| react  |     1.86479 | True      |
| crate  |     1.86479 | True      |
| trace  |     1.86479 | True      |
| cater  |     1.86479 | True      |


In [50]:
# Per-slot letter frequencies in the word bank: for each letter position, the
# letters seen there and the fraction of bank words using them, most common
# first. Columns are named explicitly because the labels produced by
# value_counts().to_frame().reset_index() changed in pandas 2.0
# ('index'/<position> before, <position>/'proportion' after), and either way
# they repeat across the five slots.
bank_letters = pd.DataFrame(wbm)
position_tables = []
for pos in range(wbm.shape[1]):
    freqs = bank_letters[pos].value_counts(normalize=True)
    position_tables.append(pd.DataFrame({
        f'letter_{pos}': freqs.index,
        f'freq_{pos}': freqs.to_numpy(),
    }))
print(tabulate(pd.concat(position_tables, axis=1), tablefmt="pipe", headers="keys"))


|    | index   |            0 | index   |           1 | index   |           2 | index   |             3 | index   |             4 |
|---:|:--------|-------------:|:--------|------------:|:--------|------------:|:--------|--------------:|:--------|--------------:|
|  0 | s       |   0.158099   | a       | 0.131317    | a       | 0.132613    | e       |   0.137365    | e       |   0.183153    |
|  1 | c       |   0.0855292  | o       | 0.120518    | i       | 0.114903    | n       |   0.0786177   | y       |   0.157235    |
|  2 | b       |   0.07473    | r       | 0.115335    | o       | 0.1054      | s       |   0.0738661   | t       |   0.109287    |
|  3 | t       |   0.0643629  | e       | 0.104536    | e       | 0.0764579   | a       |   0.0704104   | r       |   0.0915767   |
|  4 | p       |   0.0613391  | i       | 0.087257    | u       | 0.0712743   | l       |   0.0699784   | l       |   0.0673866   |
|  5 | a       |   0.0609071  | l       | 0.0868251   | r       | 0.0704104 