In [1]:
# Cell 1: Imports and Setup
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import Counter
import torch.nn.functional as F



In [2]:
# Import from modules
from wordle_core import vocabulary, wordle_feedback
from wordle_env import WordleEnv
from wordle_agent import WordleA2CNet, masked_softmax
from config import *
from evaluate import evaluate_agent_with_games


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Seed every random number generator used in this notebook (Python's `random`,
# NumPy's legacy global RNG, and PyTorch) with the same value so that runs
# are repeatable.
for seed_fn in (random.seed, np.random.seed):
    seed_fn(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7b2fe4488ed0>

In [4]:
# Load the word lists. `vocabulary` returns four values, unpacked here as the
# guess vocabulary, the solution list, the vocabulary size and a
# word -> index lookup (meanings inferred from the names and the printed
# summary; see wordle_core for the authoritative contract).
(
    vocab,
    solutions,
    VOCAB_SIZE,
    WORD_TO_IDX,
) = vocabulary(VOCAB_PATH, SOLUTIONS_PATH)


VOCABULARY:
   Solutions: 2315 words (targets only)
   Allowed guesses: 12972 words (full action space)
   Training vocabulary: 12972 words


In [8]:
# Initialize Environment and Agent
# The two environments are built from the same word lists and differ only in
# their `training_mode` flag.
env_args = (vocab, solutions, WORD_TO_IDX)
train_env = WordleEnv(*env_args, training_mode=True)
eval_env = WordleEnv(*env_args, training_mode=False)

# Network placed on the configured device; Adam with a weight decay of 1e-5.
agent = WordleA2CNet(INPUT_DIM, VOCAB_SIZE).to(device)
optimizer = optim.Adam(params=agent.parameters(), lr=LR, weight_decay=1e-5)

# Report the total number of parameter elements and the device in use.
param_count = sum(tensor.numel() for tensor in agent.parameters())
print(f"   Parameters: {param_count:,}")
print(f"   Device: {device}")


   Parameters: 2,373,933
   Device: cpu


In [130]:
# Training bookkeeping: per-evaluation win rates and average guess counts,
# per-episode mean policy entropy, and the best win rate seen so far (used to
# decide when to checkpoint).
win_rates, avg_guesses_list, entropy_history = [], [], []
best_win_rate = 0.0

In [16]:
# Training Loop
# Each episode plays one game in `train_env` by sampling actions from the
# agent's masked policy, then performs a single actor-critic gradient step on
# the collected trajectory. Every EVAL_EVERY episodes the agent is evaluated
# and the best-scoring weights are checkpointed to SAVE_PATH.
for ep in range(NUM_EPISODES):
    obs, mask = train_env.reset()
    log_probs, values, rewards, entropies = [], [], [], []
    done = False

    # Roll out one game. `done` starts False, so this body runs at least once
    # and the trajectory lists are never empty after the loop.
    while not done:
        # Add a leading axis of size 1 before feeding the network.
        obs_t = obs.unsqueeze(0).to(device)
        mask_t = mask.unsqueeze(0).to(device)

        logits, value = agent(obs_t)
        # masked_softmax also returns the temperature it applied; it receives
        # the episode index and the total episode count, presumably to anneal
        # the temperature over training (confirm in wordle_agent).
        probs, current_temp = masked_softmax(logits, mask_t, ep, NUM_EPISODES)

        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()

        obs, reward, done, mask = train_env.step(action.item())

        log_probs.append(log_prob)
        values.append(value)
        rewards.append(reward)
        entropies.append(entropy)

    # A2C update
    # Discounted return for each step, R_t = r_t + GAMMA * R_{t+1},
    # accumulated from the last step backwards and then put in time order.
    R = 0
    returns_batch = []
    for r in reversed(rewards):
        R = r + GAMMA * R
        returns_batch.append(R)
    returns_batch.reverse()

    returns_batch = torch.tensor(returns_batch, dtype=torch.float32).to(device)
    # Flatten to one entry per step. If the value head emits a trailing
    # singleton axis, the concatenation would be (T, 1); subtracting that from
    # the (T,) returns would silently broadcast to (T, T) and corrupt both the
    # advantages and the critic loss. reshape(-1) is a no-op when the tensors
    # are already one-dimensional.
    values = torch.cat(values).reshape(-1)
    log_probs = torch.cat(log_probs).reshape(-1)
    advantages = returns_batch - values

    # Mean policy entropy over the episode; used for the exploration bonus
    # below and logged to entropy_history.
    mean_entropy = torch.stack(entropies).mean()

    # Advantages are detached so the policy-gradient term does not push
    # gradients into the critic; the critic is trained by the MSE term only.
    actor_loss = -(log_probs * advantages.detach()).mean()
    critic_loss = F.mse_loss(values, returns_batch)
    entropy_loss = -ENTROPY_COEF * mean_entropy
    loss = actor_loss + critic_loss + entropy_loss

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 1.0 before stepping.
    torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
    optimizer.step()

    entropy_history.append(mean_entropy.item())

    # Evaluation
    if (ep + 1) % EVAL_EVERY == 0:
        win_rate, avg_guesses = evaluate_agent_with_games(ep + 1, num_display_games=2)
        win_rates.append(win_rate)
        avg_guesses_list.append(avg_guesses)

        # Temperature used for the last action of the most recent training episode.
        print(f"   Training: Temp={current_temp:.2f}")

        # Checkpoint only on a strict improvement in evaluation win rate.
        if win_rate > best_win_rate:
            best_win_rate = win_rate
            torch.save({
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'win_rate': win_rate,
                # Zero-based loop index: one less than the episode number
                # passed to the evaluator above.
                'episode': ep
            }, SAVE_PATH)
            print(f" Saved best model (Win Rate: {win_rate:.1%})")
        print()

print(" Training complete!")

    Episode 200 - AGENT GUESSING PROCESS:

    Game 1: Target = HOLLY
   ----------------------------------------
   Turn 1: BIRKS
   Turn 2: TANGO
   Turn 3: LOVEY
   Turn 4: WOLLY
   Turn 5: FOLLY
   Turn 6: MOLLY
    FAILED

    Game 2: Target = OVERT
   ----------------------------------------
   Turn 1: BIRKS
   Turn 2: AWMRY
   Turn 3: OUTRE
   Turn 4: OVERT
   SOLVED in 4 turns!

   Summary of 100 games:
     Win Rate: 93.0% | Avg Guesses: 0.10
     Top First Guesses: [('birks', 98)]
     Current Temperature: 1.45
   Training: Temp=1.45

    Episode 400 - AGENT GUESSING PROCESS:

    Game 1: Target = CHAMP
   ----------------------------------------
   Turn 1: TAFFY
   Turn 2: SHEAL
   Turn 3: CHANK
   Turn 4: CHARR
   Turn 5: CHACO
   Turn 6: CHAMP
   SOLVED in 6 turns!

    Game 2: Target = LOOSE
   ----------------------------------------
   Turn 1: TAFFY
   Turn 2: BURRS
   Turn 3: CISCO
   Turn 4: SHMOE
   Turn 5: GOOSE
   Turn 6: LOOSE
   SOLVED in 6 turns!

   Summary of 

In [None]:
# Mount Google Drive when running inside Google Colab.
# NOTE(review): this cell sits at the end of the notebook, yet the recorded
# output of the module-import cell shows Drive was already mounted by then.
# If any configured path lives under /content/drive, the mount has to happen
# before that path is used - consider moving this cell to the top.
try:
    from google.colab import drive
except ImportError:
    # google.colab is only importable inside a Colab runtime; elsewhere there
    # is nothing to mount, so skip instead of failing the whole notebook run.
    print("google.colab not available - skipping Drive mount")
else:
    drive.mount('/content/drive')