# Hangman ML Hackathon - Complete Solution

## ⚠️ **EXECUTION ORDER IS CRITICAL!** ⚠️

**You MUST run cells in order from top to bottom!**

**Common Error: `NameError: name 'FirstOrderHMM' is not defined`**
- **Cause**: You ran a cell that uses `FirstOrderHMM` before running the cell that defines it
- **Fix**: Run Cell 16 (defines FirstOrderHMM class) before Cell 28 (uses FirstOrderHMM)

**Quick Execution Guide:**
1. **Cell 1**: Import libraries
2. **Cells 2-13**: Load and preprocess data  
3. **Cell 14**: Define `HangmanEnv` class
4. **Cell 15**: Test environment
5. **Cell 16**: **Define `FirstOrderHMM` class** ← MUST RUN BEFORE CELL 28!
6. **Cells 17-23**: HMM training and validation
7. **Cells 24-27**: Define RL agents (`QLearningAgent` and `DQNAgent`)
8. **Cell 28+**: Training and evaluation

**🚀 Key Improvements:**
- HMM-Greedy Evaluation Mode
- Dynamic probability calculation per word
- DQN support for better learning
- Heavy HMM weighting (20x)



This notebook implements a complete Hangman solver using:
1. **Hidden Markov Model (HMM)** - Language model for letter probability estimation
2. **Reinforcement Learning (RL)** - Q-Learning agent for optimal guessing strategy
3. **Hybrid System** - Combines HMM + RL for intelligent gameplay
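
One plausible way to combine the two signals is a weighted sum of the agent's Q-value and the HMM's letter probability, skipping letters that were already guessed. The sketch below assumes that form. The names `hybrid_scores` and `pick_letter` are hypothetical, and the `hmm_weight=20.0` default simply mirrors the "20x" weighting listed under Key Improvements; the notebook's actual combination rule may differ.

```python
import string

def hybrid_scores(q_values, hmm_probs, guessed, q_weight=1.0, hmm_weight=20.0):
    """Score each unguessed letter as q_weight * Q + hmm_weight * P_hmm (assumed form)."""
    scores = {}
    for letter in string.ascii_lowercase:
        if letter in guessed:
            continue  # never re-guess a letter
        scores[letter] = (q_weight * q_values.get(letter, 0.0)
                          + hmm_weight * hmm_probs.get(letter, 0.0))
    return scores

def pick_letter(q_values, hmm_probs, guessed):
    """Greedy choice over the combined scores."""
    scores = hybrid_scores(q_values, hmm_probs, guessed)
    return max(scores, key=scores.get)

# Hypothetical values: the HMM strongly prefers 'e', the Q-table slightly prefers 't'
q = {'t': 0.5, 'e': 0.1}
p = {'e': 0.30, 't': 0.10, 'a': 0.20}
print(pick_letter(q, p, guessed={'a'}))
```

With a weight this large the HMM term dominates unless the Q-values grow well beyond the probability scale, so the RL component mostly acts as a tie-breaker in this sketch.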

## Implementation Features

✅ **Proper Training Workflow:**
- Stage 1: HMM Training (with validation)
- Stage 2: RL Baseline Setup
- Stage 3: Hybrid HMM + RL Training
- Stage 4: Final Evaluation

✅ **Overfitting/Underfitting Prevention:**
- HMM: Perplexity validation, additive smoothing
- RL: Validation monitoring, noise injection, word shuffling
- Comprehensive diagnostics
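
As a rough illustration of perplexity validation, the sketch below trains a tiny letter-bigram model with additive smoothing and compares perplexity on the training words against unseen words. It is a simplified stand-in, not the notebook's `FirstOrderHMM`: `train_bigram`, `perplexity`, the `'^'` start marker, and the toy word lists are all assumptions made for the example.

```python
import math
from collections import defaultdict

ALPHABET = 'abcdefghijklmnopqrstuvwxyz'

def train_bigram(words, alpha=0.01):
    """Return a function prob(prev, ch) = smoothed P(ch | prev); '^' marks the word start."""
    counts = defaultdict(lambda: defaultdict(int))
    for word in words:
        prev = '^'
        for ch in word:
            counts[prev][ch] += 1
            prev = ch

    def prob(prev, ch):
        row = counts.get(prev, {})  # .get avoids creating empty rows
        total = sum(row.values())
        return (row.get(ch, 0) + alpha) / (total + alpha * len(ALPHABET))

    return prob

def perplexity(prob, words):
    """exp of the average negative log-likelihood per letter."""
    log_sum, n = 0.0, 0
    for word in words:
        prev = '^'
        for ch in word:
            log_sum += math.log(prob(prev, ch))
            n += 1
            prev = ch
    return math.exp(-log_sum / n)

train = ['apple', 'ample', 'maple', 'plead']
prob = train_bigram(train)
print(f"train perplexity:    {perplexity(prob, train):.2f}")
print(f"held-out perplexity: {perplexity(prob, ['zebra', 'quick']):.2f}")
```

A large gap between the two numbers suggests the model has memorized the training words; raising the smoothing constant usually narrows the gap at the cost of a higher training perplexity.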

✅ **Online Learning:**
- HMM probabilities recalculated after each guess
- Q-values updated immediately after each action
- Step-by-step learning (not batch learning)
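
The per-step update can be sketched with the standard tabular Q-learning rule. The helper `q_update` and the string state encoding are hypothetical; the default `alpha` and `gamma` match the values in `RL_CONFIG`, and the reward of 0.5 matches `reward_correct` in `HangmanEnv`.

```python
from collections import defaultdict

def q_update(Q, state, action, reward, next_state, next_actions, done,
             alpha=0.2, gamma=0.99):
    """One tabular Q-learning step: Q(s,a) += alpha * (target - Q(s,a))."""
    if done or not next_actions:
        best_next = 0.0  # no future value after a terminal state
    else:
        best_next = max(Q[(next_state, a)] for a in next_actions)
    target = reward + gamma * best_next
    Q[(state, action)] += alpha * (target - Q[(state, action)])
    return Q[(state, action)]

Q = defaultdict(float)
# Hypothetical transition: guessing 'a' on '_____' reveals 'a____' with reward 0.5
new_q = q_update(Q, '_____', 'a', 0.5, 'a____', ['p', 'e'], done=False)
print(new_q)
```

Because the table is updated right after each guess, the very next action in the same episode already sees the new value, which is what distinguishes this from batch updates at the end of an episode.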

✅ **Data Quality:**
- Comprehensive preprocessing (case, typos, duplicates)
- Word buckets by length
- Train/validation/test split

**Run cells sequentially from top to bottom!**


In [6]:
# Import all required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
import random
import pickle
import os
import sys

# Set random seeds for reproducibility
np.random.seed(42)
random.seed(42)

# Display settings
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
%matplotlib inline

print("Libraries imported successfully!")


Libraries imported successfully!


## 1. Load Data


In [7]:
# Load data files
import os

# Get the directory containing this notebook (__file__ is undefined in Jupyter,
# so fall back to the current working directory)
notebook_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in globals() else os.getcwd()
# Go up one level to project root
project_root = os.path.dirname(notebook_dir) if os.path.basename(notebook_dir) == 'notebooks' else notebook_dir

# Try multiple possible paths
possible_corpus_paths = [
    os.path.join(project_root, 'Data', 'corpus.txt'),
    os.path.join('..', 'Data', 'corpus.txt'),
    'Data/corpus.txt',
    '../Data/corpus.txt',
    os.path.join(os.getcwd(), 'Data', 'corpus.txt')
]

possible_test_paths = [
    os.path.join(project_root, 'Data', 'test.txt'),
    os.path.join('..', 'Data', 'test.txt'),
    'Data/test.txt',
    '../Data/test.txt',
    os.path.join(os.getcwd(), 'Data', 'test.txt')
]

corpus_path = None
test_path = None

for path in possible_corpus_paths:
    if os.path.exists(path):
        corpus_path = path
        break

for path in possible_test_paths:
    if os.path.exists(path):
        test_path = path
        break

if corpus_path is None or test_path is None:
    print("ERROR: Could not find data files!")
    print(f"Tried corpus paths: {possible_corpus_paths}")
    print(f"Tried test paths: {possible_test_paths}")
    print(f"Current directory: {os.getcwd()}")
    print(f"Files in current dir: {os.listdir('.')}")
    raise FileNotFoundError("Data files not found! Please check the paths.")

print(f"Found corpus: {corpus_path}")
print(f"Found test: {test_path}")

# Load data
print("Loading raw data files...")
with open(corpus_path, 'r', encoding='utf-8') as f:
    corpus_raw = [line.strip() for line in f if line.strip()]

print(f"Raw corpus: {len(corpus_raw)} lines")

with open(test_path, 'r', encoding='utf-8') as f:
    test_words_raw = [line.strip() for line in f if line.strip()]

print(f"Raw test: {len(test_words_raw)} lines")
print(f"\nSample raw corpus words (first 10): {corpus_raw[:10]}")
print(f"Sample raw test words (first 10): {test_words_raw[:10]}")

if len(corpus_raw) == 0:
    raise ValueError("corpus_raw is empty! Check the data file.")
if len(test_words_raw) == 0:
    raise ValueError("test_words_raw is empty! Check the data file.")



Found corpus: /Users/vivannaik/Desktop/ml hackathon /Data/corpus.txt
Found test: /Users/vivannaik/Desktop/ml hackathon /Data/test.txt
Loading raw data files...
Raw corpus: 50000 lines
Raw test: 2000 lines

Sample raw corpus words (first 10): ['suburbanize', 'asmack', 'hypotypic', 'promoderationist', 'consonantly', 'philatelically', 'cacomelia', 'thicklips', 'luciferase', 'cinematography']
Sample raw test words (first 10): ['marmar', 'janet', 'dentistical', 'troveless', 'unnotify', 'gastrostenosis', 'preaffiliation', 'obpyriform', 'veratrinize', 'protection']


## 1.0 Data Cleaning and Preprocessing

Comprehensive preprocessing pipeline:
1. Case normalization (upper/lower case handling)
2. Remove non-alphabetic characters
3. Handle misspellings/typos (optional spell checking)
4. Remove duplicates
5. Filter by length and quality
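
The order of these steps matters: look-alike digits have to be mapped to letters before non-alphabetic characters are stripped, otherwise `h3llo` becomes `hllo` instead of `hello`. The compact sketch below illustrates that ordering on a single word. `LOOKALIKES` and `clean` are hypothetical names, and the mapping reuses the same digit/symbol pairs as the typo table defined in the next cell.

```python
# Digit/symbol look-alikes mapped to letters (same pairs as the notebook's typo table)
LOOKALIKES = str.maketrans({'0': 'o', '1': 'i', '3': 'e', '5': 's', '@': 'a'})

def clean(word, min_length=2, max_length=25):
    """lower-case -> map look-alikes -> drop non-letters -> length check."""
    word = word.lower().translate(LOOKALIKES)
    word = ''.join(c for c in word if c.isalpha())
    return word if min_length <= len(word) <= max_length else None

for raw in ['H3ll0!', 'pa55word', "don't", '42']:
    print(repr(raw), '->', repr(clean(raw)))
```

One caveat of this approach: a purely numeric token such as `101` is turned into the letter string `ioi` rather than being dropped, so the mapping is only safe when the corpus is known to contain words with occasional character confusions.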


In [15]:
import re
from difflib import SequenceMatcher

def normalize_case(word, case_mode='lower'):
    """
    Normalize word case.
    
    Args:
        word: Input word
        case_mode: 'lower', 'upper', or 'preserve'
    
    Returns:
        Normalized word
    """
    if case_mode == 'lower':
        return word.lower()
    elif case_mode == 'upper':
        return word.upper()
    else:
        return word

def remove_non_alphabetic(word):
    """
    Remove all non-alphabetic characters.
    
    Args:
        word: Input word
    
    Returns:
        Word with only alphabetic characters
    """
    return ''.join([c for c in word if c.isalpha()])

def remove_special_characters(word):
    """
    Remove every character that is not an ASCII letter (a-z, A-Z).
    """
    # Keep only ASCII letters
    return re.sub(r'[^a-zA-Z]', '', word)

def fix_common_typos(word):
    """
    Fix common typos and character confusions.
    """
    # Common typo fixes
    typo_fixes = {
        # Common character confusions
        '0': 'o', 'O': 'o',
        '1': 'i', 'I': 'i',
        '3': 'e',
        '5': 's',
        '@': 'a',
        # Remove numbers and special chars (handled separately)
    }
    
    fixed = word
    for typo, correct in typo_fixes.items():
        fixed = fixed.replace(typo, correct)
    
    return fixed

def is_valid_word(word, min_length=2, max_length=25):
    """
    Check if word is valid after cleaning.
    
    Args:
        word: Word to check
        min_length: Minimum valid length
        max_length: Maximum valid length
    
    Returns:
        bool: True if word is valid
    """
    if not word:
        return False
    
    if len(word) < min_length or len(word) > max_length:
        return False
    
    # Check if word contains only alphabetic characters
    if not word.isalpha():
        return False
    
    return True

def remove_duplicates_keep_order(words):
    """
    Remove duplicates while preserving order.
    """
    seen = set()
    unique_words = []
    for word in words:
        if word not in seen:
            seen.add(word)
            unique_words.append(word)
    return unique_words

def clean_and_preprocess_word(word, 
                             normalize_case_mode='lower',
                             remove_non_alpha=True,
                             fix_typos=True,
                             min_length=2,
                             max_length=25):
    """
    Complete preprocessing pipeline for a single word.
    
    Args:
        word: Raw word
        normalize_case_mode: 'lower', 'upper', or 'preserve'
        remove_non_alpha: Remove non-alphabetic characters
        fix_typos: Attempt to fix common typos
        min_length: Minimum word length
        max_length: Maximum word length
    
    Returns:
        Cleaned word or None if invalid
    """
    # Step 1: Normalize case
    cleaned = normalize_case(word, normalize_case_mode)
    
    # Step 2: Fix common typos (do before removing non-alpha to catch digit confusions)
    if fix_typos:
        cleaned = fix_common_typos(cleaned)
    
    # Step 3: Remove non-alphabetic characters
    if remove_non_alpha:
        cleaned = remove_non_alphabetic(cleaned)
    
    # Step 4: Validate
    if not is_valid_word(cleaned, min_length, max_length):
        return None
    
    return cleaned


def preprocess_word_list(words,
                         normalize_case_mode='lower',
                         remove_non_alpha=True,
                         fix_typos=True,
                         min_length=3,
                         max_length=15,
                         remove_duplicates=True,
                         verbose=True):
    """
    Clean a list of raw words and collect preprocessing statistics.
    
    Args:
        words: List of raw words
        normalize_case_mode: 'lower', 'upper', or 'preserve'
        remove_non_alpha: Remove non-alphabetic characters
        fix_typos: Attempt to fix common typos
        min_length: Minimum word length
        max_length: Maximum word length
        remove_duplicates: Drop repeated words, keeping first occurrences
        verbose: Print a summary of the statistics
    
    Returns:
        (cleaned_words, stats): cleaned word list and a dict of counts
    """
    stats = {
        'original_count': len(words),
        'removed_words': []
    }
    cleaned_words = []
    
    for word in words:
        original = word
        
        # Clean word
        cleaned = clean_and_preprocess_word(
            word,
            normalize_case_mode=normalize_case_mode,
            remove_non_alpha=remove_non_alpha,
            fix_typos=fix_typos,
            min_length=min_length,
            max_length=max_length
        )
        
        if cleaned is not None:
            cleaned_words.append(cleaned)
        else:
            stats['removed_words'].append(original)
    
    stats['after_length_filter'] = len(cleaned_words)
    
    # Remove duplicates
    if remove_duplicates:
        before_dup_removal = len(cleaned_words)
        cleaned_words = remove_duplicates_keep_order(cleaned_words)
        stats['after_duplicate_removal'] = len(cleaned_words)
        stats['duplicates_removed'] = before_dup_removal - len(cleaned_words)
    
    stats['final_count'] = len(cleaned_words)
    
    if verbose:
        print(f"Preprocessing Statistics:")
        print(f"  Original words: {stats['original_count']}")
        print(f"  After cleaning: {stats['final_count']}")
        print(f"  Removed: {stats['original_count'] - stats['final_count']}")
        if remove_duplicates:
            print(f"  Duplicates removed: {stats.get('duplicates_removed', 0)}")
        if stats['original_count'] > 0:
            print(f"  Removal rate: {(stats['original_count'] - stats['final_count']) / stats['original_count'] * 100:.2f}%")
        else:
            print(f"  Removal rate: N/A (original count was 0)")
    
    return cleaned_words, stats

print("Preprocessing functions defined!")


In [None]:
# Preprocess corpus
print("="*60)
print("PREPROCESSING CORPUS")
print("="*60)
corpus, corpus_stats = preprocess_word_list(
    corpus_raw,
    normalize_case_mode='lower',  # Convert all to lowercase
    remove_non_alpha=True,
    fix_typos=True,
    min_length=3,   # Minimum 3 letters (adjust as needed)
    max_length=20,  # Maximum 20 letters
    remove_duplicates=True,
    verbose=True
)

print(f"\nCorpus sample (first 20): {corpus[:20]}")

# Show some removed words (if any)
if len(corpus_stats['removed_words']) > 0:
    print(f"\nSample of removed words (first 10): {corpus_stats['removed_words'][:10]}")


In [None]:
# Preprocess test set
print("\n" + "="*60)
print("PREPROCESSING TEST SET")
print("="*60)
test_words, test_stats = preprocess_word_list(
    test_words_raw,
    normalize_case_mode='lower',  # Same as corpus
    remove_non_alpha=True,
    fix_typos=True,
    min_length=3,   # Same as corpus
    max_length=20,  # Same as corpus
    remove_duplicates=True,
    verbose=True
)

print(f"\nTest set sample (first 20): {test_words[:20]}")

# Check for overlap between corpus and test (data leakage check)
corpus_set = set(corpus)
test_set = set(test_words)
overlap = corpus_set & test_set

print(f"\n{'='*60}")
print("DATA LEAKAGE CHECK")
print(f"{'='*60}")
print(f"Words in corpus: {len(corpus_set)}")
print(f"Words in test: {len(test_set)}")
print(f"Overlapping words: {len(overlap)}")

if len(overlap) > 0:
    print(f"⚠️  WARNING: {len(overlap)} words appear in both corpus and test set!")
    print(f"   Sample overlap: {list(overlap)[:10]}")
    print(f"   This may cause data leakage. Consider removing from test set.")
else:
    print(f"✅ No overlap detected - good data separation!")


## 1.1 Data Quality Analysis


In [None]:
# Data quality analysis
def analyze_word_quality(words, name="Dataset"):
    """Analyze quality metrics of word list"""
    if not words:
        print(f"{name}: No words to analyze")
        return
    
    # Length distribution
    lengths = [len(w) for w in words]
    
    # Character analysis
    all_chars = ''.join(words)
    char_counts = Counter(all_chars)
    
    # Unique characters
    unique_chars = set(all_chars)
    
    print(f"\n{'='*60}")
    print(f"{name} QUALITY ANALYSIS")
    print(f"{'='*60}")
    print(f"Total words: {len(words)}")
    print(f"Unique words: {len(set(words))}")
    print(f"Average length: {np.mean(lengths):.2f} characters")
    print(f"Min length: {min(lengths)}")
    print(f"Max length: {max(lengths)}")
    print(f"Most common length: {Counter(lengths).most_common(1)[0][0]} letters ({Counter(lengths).most_common(1)[0][1]} words)")
    print(f"\nUnique characters: {len(unique_chars)}")
    print(f"All lowercase: {all(w.islower() for w in words)}")
    print(f"All alphabetic: {all(w.isalpha() for w in words)}")
    
    # Character frequency
    print(f"\nTop 10 most common characters:")
    for char, count in char_counts.most_common(10):
        freq = count / len(all_chars) * 100
        print(f"  '{char}': {count:,} times ({freq:.2f}%)")
    
    # Visualize length distribution
    plt.figure(figsize=(10, 5))
    plt.hist(lengths, bins=range(min(lengths), max(lengths)+2), edgecolor='black', alpha=0.7)
    plt.xlabel('Word Length')
    plt.ylabel('Frequency')
    plt.title(f'{name} - Word Length Distribution')
    plt.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()

# Analyze cleaned datasets
analyze_word_quality(corpus, "CORPUS")
analyze_word_quality(test_words, "TEST SET")


## 1.2 Split into Training and Validation Sets
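
The cell below slices the corpus first and shuffles each part afterwards, so which words land in validation is decided entirely by corpus order. If the source file happens to be sorted (alphabetically or by length), shuffling before slicing avoids a skewed validation set. A minimal sketch of that alternative follows; `split_corpus` and the stand-in word list are hypothetical.

```python
import random

def split_corpus(words, train_frac=0.9, seed=42):
    """Shuffle a copy with a fixed seed, then slice into train/validation."""
    shuffled = list(words)  # leave the caller's list untouched
    random.Random(seed).shuffle(shuffled)
    split_idx = int(len(shuffled) * train_frac)
    return shuffled[:split_idx], shuffled[split_idx:]

words = [f"word{i}" for i in range(20)]  # stand-in for the cleaned corpus
train, val = split_corpus(words)
print(len(train), len(val))
```

Using a dedicated `random.Random(seed)` instance keeps the split reproducible without disturbing the global random state that later training cells rely on.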


In [None]:
# Split corpus into training and validation sets
# Use 90% for training, 10% for validation
split_idx = int(len(corpus) * 0.9)
training_corpus = corpus[:split_idx]
validation_corpus = corpus[split_idx:]

print(f"Data split:")
print(f"  Training: {len(training_corpus)} words ({len(training_corpus)/len(corpus)*100:.1f}%)")
print(f"  Validation: {len(validation_corpus)} words ({len(validation_corpus)/len(corpus)*100:.1f}%)")
print(f"  Test: {len(test_words)} words")

# Shuffle each split in place so later training does not see words in corpus order
random.shuffle(training_corpus)
random.shuffle(validation_corpus)

print(f"\n✅ Data preprocessing complete!")
print(f"   All words are lowercase, alphabetic, and within length range [3, 20]")


# Guards to ensure non-empty datasets
if len(training_corpus) == 0 and len(corpus) > 0:
    training_corpus = corpus[: max(1, int(0.9*len(corpus)))]
if len(validation_corpus) == 0 and len(corpus) > 0:
    validation_corpus = corpus[-max(1, len(corpus)-len(training_corpus)):]
if 'test_words' in globals() and len(test_words) == 0 and 'test_words_raw' in globals() and len(test_words_raw) > 0:
    test_words = [w.strip().lower() for w in test_words_raw if w and w.strip().isalpha() and 3 <= len(w) <= 20]



## 1.3 Organize Words into Buckets by Length


## 4.1 RL Agent - Hyperparameters and Tuning
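
To get a feel for how long exploration lasts under a multiplicative schedule, the sketch below steps epsilon down by a constant factor per episode with a floor, and counts the episodes needed to reach that floor. The function names are hypothetical; the defaults mirror `epsilon`, `epsilon_decay`, and `epsilon_min` from `RL_CONFIG`, and the per-episode update rule is an assumption about how the agent applies them.

```python
def epsilon_schedule(epsilon=1.0, decay=0.999, epsilon_min=0.02):
    """Yield epsilon for each episode under multiplicative decay with a floor."""
    while True:
        yield epsilon
        epsilon = max(epsilon_min, epsilon * decay)

def episodes_until_floor(epsilon=1.0, decay=0.999, epsilon_min=0.02):
    """Index of the first episode whose epsilon equals the floor."""
    for episode, eps in enumerate(epsilon_schedule(epsilon, decay, epsilon_min)):
        if eps <= epsilon_min:
            return episode

print(episodes_until_floor())
```

If the planned number of training episodes is much smaller than this count, the agent never leaves the high-exploration regime, and a smaller decay factor is worth considering.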


## 2. Hangman Environment

**IMPORTANT**: Run this cell to define the `HangmanEnv` class before using it!


In [None]:
# RL Agent Hyperparameters
RL_CONFIG = {
    'learning_rate': 0.2,        # α: Further increased for faster learning
    'discount_factor': 0.99,     # γ: Higher for very long-term planning
    'epsilon': 1.0,              # Initial exploration (100% random)
    'epsilon_decay': 0.999,      # Slower decay = more exploration time
    'epsilon_min': 0.02,         # Lower min (less random at end, trust HMM more)
    'q_weight': 1.0
}

print("="*70)
print("RL AGENT HYPERPARAMETERS")
print("="*70)
print(f"Learning Rate (α): {RL_CONFIG['learning_rate']}")
print(f"  → Too high: Oscillations, unstable learning")
print(f"  → Too low: Slow convergence")
print(f"  → Recommended: 0.05-0.2 for Q-learning")
print(f"\nDiscount Factor (γ): {RL_CONFIG['discount_factor']}")
print(f"  → High (0.95-0.99): Focus on long-term rewards")
print(f"  → Low (0.7-0.9): Focus on immediate rewards")
print(f"\nExploration Schedule:")
print(f"  → Initial ε: {RL_CONFIG['epsilon']} ({RL_CONFIG['epsilon']:.0%} exploration)")
print(f"  → Decay: {RL_CONFIG['epsilon_decay']} per episode")
print(f"  → Final ε: {RL_CONFIG['epsilon_min']} ({RL_CONFIG['epsilon_min']:.0%} exploration)")
print(f"  → Episodes to 50%: {np.log(0.5) / np.log(RL_CONFIG['epsilon_decay']):.0f}")
print(f"\n💡 Tuning Tips:")
print(f"  - If win rate not improving: Try higher learning rate or stronger rewards")
print(f"  - If too random: Decrease epsilon_decay (faster shift to exploitation)")
print(f"  - If stuck in local optimum: Increase epsilon_min (more exploration)")
print("="*70)




In [None]:
# Evaluation mode: 'hybrid' uses Q+HMM, 'hmm' uses pure HMM
EVAL_MODE = 'hybrid'
print('EVAL_MODE =', EVAL_MODE)


## 3. HMM Model

**⚠️ CRITICAL**: Run this cell to define `FirstOrderHMM` class!  
**You MUST run this before any cell that uses `hmm = FirstOrderHMM(...)`**


In [None]:
# Organize words into buckets by length
def bucket_words_by_length(words, min_length=3, max_length=20):
    """
    Organize words into buckets by length.
    
    Returns:
        dict: {length: [list of words]}
    """
    buckets = defaultdict(list)
    
    for word in words:
        word_len = len(word)
        if min_length <= word_len <= max_length:
            buckets[word_len].append(word)
    
    return buckets

# Bucket all datasets by length
training_buckets = bucket_words_by_length(training_corpus)
validation_buckets = bucket_words_by_length(validation_corpus)
test_buckets = bucket_words_by_length(test_words)

print("Word buckets by length:")
print(f"\nTraining corpus buckets:")
for length in sorted(training_buckets.keys()):
    print(f"  Length {length:2d}: {len(training_buckets[length]):5d} words")

print(f"\nValidation corpus buckets:")
for length in sorted(validation_buckets.keys()):
    print(f"  Length {length:2d}: {len(validation_buckets[length]):5d} words")

print(f"\nTest corpus buckets:")
for length in sorted(test_buckets.keys()):
    print(f"  Length {length:2d}: {len(test_buckets[length]):5d} words")

# Visualize word length distribution
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Training distribution
train_lengths = sorted(training_buckets.keys())
train_counts = [len(training_buckets[l]) for l in train_lengths]
axes[0].bar(train_lengths, train_counts, color='blue', alpha=0.7)
axes[0].set_xlabel('Word Length')
axes[0].set_ylabel('Count')
axes[0].set_title('Training Corpus - Word Length Distribution')
axes[0].grid(True, alpha=0.3, axis='y')

# Validation distribution
val_lengths = sorted(validation_buckets.keys())
val_counts = [len(validation_buckets[l]) for l in val_lengths]
axes[1].bar(val_lengths, val_counts, color='orange', alpha=0.7)
axes[1].set_xlabel('Word Length')
axes[1].set_ylabel('Count')
axes[1].set_title('Validation Corpus - Word Length Distribution')
axes[1].grid(True, alpha=0.3, axis='y')

# Test distribution
test_lengths = sorted(test_buckets.keys())
test_counts = [len(test_buckets[l]) for l in test_lengths]
axes[2].bar(test_lengths, test_counts, color='green', alpha=0.7)
axes[2].set_xlabel('Word Length')
axes[2].set_ylabel('Count')
axes[2].set_title('Test Corpus - Word Length Distribution')
axes[2].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Summary statistics
print(f"\nSummary Statistics:")
print(f"Training: {sum(len(words) for words in training_buckets.values())} words")
print(f"Validation: {sum(len(words) for words in validation_buckets.values())} words")
print(f"Test: {sum(len(words) for words in test_buckets.values())} words")
print(f"\nMost common length in training: {max(training_buckets.keys(), key=lambda k: len(training_buckets[k]))} letters")


## 2.3 Letter Frequency Heuristics

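Fixed English letter-frequency tables are no longer used as a fallback. The HMM instead computes letter probabilities dynamically for each word, from its length, the blank position, and the letters already revealed.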


In [None]:
# Note: We NO LONGER use predetermined letter frequencies
# Instead, HMM calculates DYNAMIC probabilities for EACH word based on:
# - Word-specific context (known letters, position)
# - Bigram patterns in this specific word
# - Position-specific probabilities for this word length
# - Transition probabilities from this word's context

print("="*70)
print("DYNAMIC PROBABILITY CALCULATION")
print("="*70)
print("✓ Probabilities calculated dynamically for EACH word")
print("✓ Uses word-specific context (length, position, known letters)")
print("✓ Top 5 suggestions generated dynamically per word")
print("✓ NO predetermined/fixed probabilities")
print("="*70)


## 2. Hangman Environment


In [None]:
class HangmanEnv:
    """Hangman game environment"""
    
    def __init__(self, word, max_lives=6, max_guesses=25):
        self.word = word.lower()
        self.max_lives = max_lives
        self.max_guesses = max_guesses
        self.lives = max_lives
        self.guessed_letters = set()
        self.masked_word = ['_' for _ in self.word]
        self.num_guesses = 0
        
        # Reward parameters (tunable)
        self.reward_correct = 0.5
        self.reward_wrong = -1.0
        self.reward_repeated = -0.5
        self.reward_win = 10.0
        self.reward_lose = -5.0
    
    def get_state(self):
        """Get current state representation"""
        masked_str = ''.join(self.masked_word)
        return {
            'masked_word': masked_str,
            'guessed_letters': self.guessed_letters.copy(),
            'lives_left': self.lives,
            'word_length': len(self.word),
            'num_guesses': self.num_guesses
        }
    
    def guess_letter(self, letter):
        """
        Guess a letter. Returns (reward, new_state, done, info)
        """
        letter = letter.lower()
        self.num_guesses += 1
        
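        # Check repeated guess (costs no life, but still counts toward max_guesses
        # so an agent that keeps repeating a letter cannot loop forever)
        if letter in self.guessed_letters:
            if self.num_guesses >= self.max_guesses:
                return self.reward_lose, self.get_state(), True, {'status': 'lost'}
            reward = self.reward_repeated
            return reward, self.get_state(), False, {'status': 'repeated'}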
        
        self.guessed_letters.add(letter)
        
        # Check if correct
        if letter in self.word:
            # Update masked word
            for i, char in enumerate(self.word):
                if char == letter:
                    self.masked_word[i] = letter
            
            # Check win
            if '_' not in self.masked_word:
                reward = self.reward_win
                return reward, self.get_state(), True, {'status': 'won'}
            else:
                reward = self.reward_correct
                return reward, self.get_state(), False, {'status': 'correct'}
        else:
            # Wrong guess
            self.lives -= 1
            
            if self.lives == 0 or self.num_guesses >= self.max_guesses:
                reward = self.reward_lose
                return reward, self.get_state(), True, {'status': 'lost'}
            else:
                reward = self.reward_wrong
                return reward, self.get_state(), False, {'status': 'wrong'}
    
    def reset(self, word=None):
        """Reset environment with new word"""
        if word:
            self.word = word.lower()
        self.lives = self.max_lives
        self.guessed_letters = set()
        self.masked_word = ['_' for _ in self.word]
        self.num_guesses = 0
        return self.get_state()

# Test environment
print("Testing Hangman Environment...")
env = HangmanEnv("apple")
print(f"Initial state: {env.get_state()['masked_word']}")

reward, state, done, info = env.guess_letter('a')
print(f"Guess 'a': {state['masked_word']}, reward={reward}, done={done}")

reward, state, done, info = env.guess_letter('p')
print(f"Guess 'p': {state['masked_word']}, reward={reward}, done={done}")

reward, state, done, info = env.guess_letter('x')
print(f"Guess 'x': {state['masked_word']}, reward={reward}, lives={state['lives_left']}, done={done}")
print("Environment working correctly!")


## 3. HMM Model
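
The class below relies on additive smoothing, P(x) = (count(x) + α) / (total + α·V) with V = 26, so that letters never observed in a given context still receive a small non-zero probability. A standalone sketch of that formula follows; `smoothed_distribution` and the example counts are hypothetical and only illustrate the arithmetic.

```python
from collections import Counter

ALPHABET = 'abcdefghijklmnopqrstuvwxyz'

def smoothed_distribution(counts, alpha=0.01):
    """P(x) = (count(x) + alpha) / (total + alpha * V) over the 26 lowercase letters."""
    total = sum(counts.values())
    denom = total + alpha * len(ALPHABET)
    return {ch: (counts.get(ch, 0) + alpha) / denom for ch in ALPHABET}

# Hypothetical counts of letters observed after 'q'
after_q = Counter({'u': 98, 'a': 2})
probs = smoothed_distribution(after_q)
print(f"P(u|q) = {probs['u']:.4f}, P(z|q) = {probs['z']:.6f}")
print(f"sum = {sum(probs.values()):.6f}")
```

Adding α to every count and α·V to the denominator keeps the distribution normalized, which is why the same α has to appear in both places.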


In [None]:
class FirstOrderHMM:
    """
    1st-order HMM: P(word) = P(l1, l2, ..., ln) = ∏ P(li | li-1)
    
    Structure:
    - Hidden states: letters (26 states) + start/end tokens
    - Emissions: letters (26 observations)
    - Transitions: A[i][j] = P(letter_j | letter_i)
    - Emissions: B[i][k] = P(obs_k | state_i) = identity (state emits itself)
    """
    
    def __init__(self, smoothing=0.01):
        self.smoothing = smoothing  # Additive (Laplace) smoothing parameter
        self.vocab_size = 26
        
        # Transition matrix: A[i][j] = P(j | i)
        # i, j are letter indices (0=a, 1=b, ..., 25=z)
        self.transition_counts = defaultdict(lambda: defaultdict(int))
        self.transition_probs = defaultdict(lambda: defaultdict(float))
        
        # Start probabilities: P(first letter)
        self.start_counts = defaultdict(int)
        self.start_probs = defaultdict(float)
        
        # End probabilities: P(end | last letter)
        self.end_counts = defaultdict(int)
        self.end_probs = defaultdict(float)
        
        # Letter frequency (for fallback)
        self.letter_freq = defaultdict(int)
        self.total_letters = 0
        
        # Position-based probabilities (fallback for different word lengths)
        self.position_letter_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        
        # Bigram patterns for common word structures
        self.bigram_counts = defaultdict(int)
        
    def train(self, corpus, max_order=1):
        """Train 1st-order HMM on corpus"""
        print("Training 1st-order HMM...")
        
        # Count transitions and positions
        # Also track bigrams and common patterns
        bigram_counts = defaultdict(int)  # Track common bigrams like "AP", "PP", etc.
        
        for word in corpus:
            if len(word) == 0:
                continue
            
            # Track position-based counts (for position-dependent predictions)
            word_len = len(word)
            for pos, letter in enumerate(word):
                self.position_letter_counts[word_len][pos][letter] += 1
                self.letter_freq[letter] += 1
                self.total_letters += 1
            
            # Track transitions (1st-order: letter depends on previous letter)
            for i in range(len(word)):
                letter = word[i]
                
                if i == 0:
                    # Start probability
                    self.start_counts[letter] += 1
                else:
                    # Transition: previous letter -> current letter
                    prev_letter = word[i-1]
                    self.transition_counts[prev_letter][letter] += 1
                    # Track bigram
                    bigram = prev_letter + letter
                    bigram_counts[bigram] += 1
                
                # End probability: last letter -> end
                if i == len(word) - 1:
                    self.end_counts[letter] += 1
        
        # Store bigram counts for pattern matching
        self.bigram_counts = bigram_counts
        
        # Compute transition probabilities with smoothing
        print("Computing transition probabilities...")
        for prev_letter in self.transition_counts:
            total = sum(self.transition_counts[prev_letter].values())
            for letter in 'abcdefghijklmnopqrstuvwxyz':
                count = self.transition_counts[prev_letter].get(letter, 0)
                # Additive smoothing: P(j|i) = (count + α) / (total + α * vocab_size)
                self.transition_probs[prev_letter][letter] = (count + self.smoothing) / (total + self.smoothing * self.vocab_size)
        
        # Compute start probabilities
        total_starts = sum(self.start_counts.values())
        for letter in 'abcdefghijklmnopqrstuvwxyz':
            count = self.start_counts.get(letter, 0)
            self.start_probs[letter] = (count + self.smoothing) / (total_starts + self.smoothing * self.vocab_size)
        
        # Compute end probabilities
        total_ends = sum(self.end_counts.values())
        for letter in 'abcdefghijklmnopqrstuvwxyz':
            count = self.end_counts.get(letter, 0)
            self.end_probs[letter] = (count + self.smoothing) / (total_ends + self.smoothing * self.vocab_size)
        
        print(f"Trained on {len(corpus)} words")
        print(f"Total letter transitions: {sum(sum(v.values()) for v in self.transition_counts.values())}")
    
    def get_letter_probability_given_prev(self, prev_letter, letter):
        """Get P(letter | prev_letter) using transition probabilities"""
        if prev_letter is None:
            # Start probability
            return self.start_probs.get(letter, self.smoothing / self.vocab_size)
        return self.transition_probs.get(prev_letter, {}).get(letter, self.smoothing / self.vocab_size)
    
    def get_letter_probability_by_position(self, word_length, position, letter):
        """Get P(letter | position, length) - position-based fallback"""
        # Fall back to the smoothed corpus-wide letter frequency when this
        # (length, position) pair was never seen in training
        fallback = (self.letter_freq.get(letter, 0) + self.smoothing) / (self.total_letters + self.smoothing * self.vocab_size)
        if word_length not in self.position_letter_counts:
            return fallback
        
        if position not in self.position_letter_counts[word_length]:
            return fallback
        
        counts = self.position_letter_counts[word_length][position]
        total = sum(counts.values()) + self.smoothing * self.vocab_size
        
        letter_count = counts.get(letter, 0) + self.smoothing
        return letter_count / total
    
    def get_letter_probability_given_next(self, letter, next_letter):
        """Get P(letter | next_letter) - reverse transition"""
        if next_letter is None:
            # End probability (reverse)
            return self.end_probs.get(letter, self.smoothing / self.vocab_size)
        # We need reverse transitions: P(letter | next_letter)
        # Approximate using forward transitions: P(next_letter | letter)
        # Then use Bayes' rule approximation
        forward_prob = self.transition_probs.get(letter, {}).get(next_letter, self.smoothing / self.vocab_size)
        letter_freq = self.letter_freq.get(letter, self.smoothing) / max(self.total_letters, 1)
        next_freq = self.letter_freq.get(next_letter, self.smoothing) / max(self.total_letters, 1)
        if next_freq > 0:
            reverse_prob = (forward_prob * letter_freq) / next_freq
            return min(reverse_prob, 1.0)
        return forward_prob
    
    def get_probabilities_for_mask(self, masked_word, guessed_letters=None):
        """
        Get probability distribution over alphabet for given masked word.
        IMPROVED: Uses bidirectional context (previous AND next letters).
        """
        # Avoid a mutable default argument
        if guessed_letters is None:
            guessed_letters = set()
        word_length = len(masked_word)
        letter_probs = defaultdict(float)
        
        # Get known letters and their positions
        known_letters = {}
        blank_positions = []
        for i, char in enumerate(masked_word):
            if char == '_':
                blank_positions.append(i)
            else:
                known_letters[i] = char
        
        if not blank_positions:
            return {letter: 0 for letter in 'abcdefghijklmnopqrstuvwxyz'}
        
        # For each blank position, calculate probabilities using BIDIRECTIONAL context
        for letter in 'abcdefghijklmnopqrstuvwxyz':
            if letter in guessed_letters:
                letter_probs[letter] = 0
                continue
            
            position_probs = []
            
            for blank_pos in blank_positions:
                # Strategy 1: Forward context (previous letter)
                prev_letter = None
                for i in range(blank_pos - 1, -1, -1):
                    if i in known_letters:
                        prev_letter = known_letters[i]
                        break
                
                forward_prob = self.get_letter_probability_given_prev(prev_letter, letter)
                
                # Strategy 2: Backward context (next letter) - NEW!
                next_letter = None
                for i in range(blank_pos + 1, len(masked_word)):
                    if i in known_letters:
                        next_letter = known_letters[i]
                        break
                
                backward_prob = self.get_letter_probability_given_next(letter, next_letter)
                
                # Strategy 3: Position-based probabilities
                pos_prob = self.get_letter_probability_by_position(word_length, blank_pos, letter)
                
                # Strategy 4: Bigram pattern matching (improved)
                # Check if letter fits common patterns
                pattern_boost = 1.0
                
                # Check bigram patterns
                if prev_letter:
                    bigram_prev = prev_letter + letter
                    bigram_freq = self.bigram_counts.get(bigram_prev, 0)
                    if bigram_freq > 100:  # Common bigram
                        pattern_boost *= (1.0 + min(bigram_freq / 1000, 1.0))  # Boost up to 2x
                
                if next_letter:
                    bigram_next = letter + next_letter
                    bigram_freq = self.bigram_counts.get(bigram_next, 0)
                    if bigram_freq > 100:  # Common bigram
                        pattern_boost *= (1.0 + min(bigram_freq / 1000, 1.0))  # Boost up to 2x
                
                # Special case: Common patterns like "_PPLE" -> should suggest 'A'
                if blank_pos == 0 and next_letter == 'p':
                    # Check if "AP" is a common bigram
                    if self.bigram_counts.get('ap', 0) > 50:
                        if letter == 'a':
                            pattern_boost *= 3.0  # Strong boost for "APPLE" pattern
                
                # Pattern like "_HAT" -> should suggest 'C' or 'T'
                if blank_pos == 0 and next_letter == 'h':
                    # Common patterns: "CH", "TH", "WH"
                    if letter in ['c', 't', 'w']:
                        bigram = letter + 'h'
                        if self.bigram_counts.get(bigram, 0) > 50:
                            pattern_boost *= 2.5
                
                # Combine all strategies:
                # Forward context: 40%
                # Backward context: 30% (helps with patterns like "_pple")
                # Position: 30%
                # Pattern boost: multiplier
                combined_prob = (0.4 * forward_prob + 
                               0.3 * backward_prob + 
                               0.3 * pos_prob) * pattern_boost
                
                position_probs.append(combined_prob)
            
            # Blend MAX (70%) with MEAN (30%): a letter that fits ANY blank well is promising,
            # but its fit across all blanks still contributes
            letter_probs[letter] = max(position_probs) * 0.7 + np.mean(position_probs) * 0.3
        
        # Normalize
        total = sum(letter_probs.values())
        if total > 0:
            letter_probs = {k: v/total for k, v in letter_probs.items()}
        
        return letter_probs
    
    def calculate_log_likelihood(self, word):
        """Calculate log-likelihood of a word: log P(word)"""
        if len(word) == 0:
            return float('-inf')
        
        log_prob = 0.0
        
        for i in range(len(word)):
            letter = word[i]
            
            if i == 0:
                # Start probability
                log_prob += np.log(max(self.start_probs.get(letter, self.smoothing / self.vocab_size), 1e-10))
            else:
                # Transition probability
                prev_letter = word[i-1]
                trans_prob = self.transition_probs.get(prev_letter, {}).get(letter, self.smoothing / self.vocab_size)
                log_prob += np.log(max(trans_prob, 1e-10))
        
        return log_prob
    
    def calculate_perplexity(self, corpus):
        """Calculate perplexity = exp(-average log likelihood)"""
        total_log_likelihood = 0.0
        total_letters = 0
        
        for word in corpus:
            log_likelihood = self.calculate_log_likelihood(word)
            total_log_likelihood += log_likelihood
            total_letters += len(word)
        
        avg_log_likelihood = total_log_likelihood / total_letters if total_letters > 0 else 0
        perplexity = np.exp(-avg_log_likelihood)
        
        return perplexity

print("1st-order HMM class defined!")
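
`get_letter_probability_given_next` recovers a reverse transition from forward statistics with Bayes' rule, P(letter | next) = P(next | letter) · P(letter) / P(next). The sketch below isolates that step. The function name `reverse_transition` and the probabilities are illustrative assumptions, not corpus estimates.

```python
def reverse_transition(forward_prob, letter_freq, next_freq):
    """P(letter | next) from P(next | letter), P(letter) and P(next) via Bayes' rule."""
    if next_freq <= 0:
        return forward_prob  # same fallback as the class method
    # The inputs are separately smoothed estimates, so the ratio can exceed 1; clip it
    return min(forward_prob * letter_freq / next_freq, 1.0)

# Hypothetical values: P(u | q) = 0.9, P(q) = 0.01, P(u) = 0.03
p_q_given_u = reverse_transition(forward_prob=0.9, letter_freq=0.01, next_freq=0.03)
print(f"P(q | u) = {p_q_given_u:.3f}")
```

The clip to 1.0 is needed because the forward probability and the unigram frequencies are estimated independently and are not guaranteed to be mutually consistent.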


## 2.2 Reward Function Design

**Critical for RL success!** Reward design heavily impacts agent behavior.


In [None]:
# Reward Function Configuration
# These parameters are CRITICAL for RL training success!

REWARD_CONFIG = {
    'correct': 0.5,      # +reward for correct guess (immediate feedback)
    'wrong': -1.0,       # -penalty for wrong guess (balance exploration)
    'repeated': -0.5,    # -penalty for repeated guess (discourage inefficiency)
    'win': 10.0,         # +bonus for winning (encourage completion)
    'lose': -5.0,        # -penalty for losing (less than win bonus)
}

print("="*70)
print("REWARD FUNCTION DESIGN")
print("="*70)
print("Reward Parameters:")
print(f"  Correct guess:     +{REWARD_CONFIG['correct']:.1f}")
print(f"  Wrong guess:        {REWARD_CONFIG['wrong']:.1f}")
print(f"  Repeated guess:     {REWARD_CONFIG['repeated']:.1f}")
print(f"  Win game:           +{REWARD_CONFIG['win']:.1f}")
print(f"  Lose game:          {REWARD_CONFIG['lose']:.1f}")

print(f"\n💡 Tuning Guidelines:")
print(f"  - Correct guess: 0.5-2.0 (immediate positive feedback)")
print(f"  - Wrong guess: -0.5 to -2.0 (balance exploration vs caution)")
print(f"  - Win bonus: 10-50 (encourage completing games)")
print(f"  - Lose penalty: -5 to -20 (less than win bonus)")
print(f"\nCurrent Balance:")
total_win_reward = REWARD_CONFIG['win'] + (REWARD_CONFIG['correct'] * 5)  # Assume 5 correct guesses
total_lose_penalty = abs(REWARD_CONFIG['lose']) + (abs(REWARD_CONFIG['wrong']) * 6)  # 6 wrong = lose
print(f"  Typical win reward: ~{total_win_reward:.1f}")
print(f"  Typical lose penalty: ~{total_lose_penalty:.1f}")
if total_win_reward > total_lose_penalty:
    print(f"  ✓ Win reward > lose penalty (good - encourages playing)")
else:
    print(f"  ⚠️  Win reward <= lose penalty (may discourage playing)")

print("="*70)

# HangmanEnv defines its own reward values in HangmanEnv.__init__.
# REWARD_CONFIG is not passed to the environment in this cell, so keep the two in sync.
print("\n✓ HangmanEnv is expected to use these reward values")
print("  (Defined in HangmanEnv.__init__)")
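
To see how the reward parameters add up over one game, the sketch below sums per-guess rewards and the terminal bonus. It assumes the terminal bonus or penalty is added on top of the last guess's reward, which matches the "Current Balance" estimate above but is not verified against `HangmanEnv`. The helper `episode_return` is hypothetical.

```python
REWARDS = {'correct': 0.5, 'wrong': -1.0, 'repeated': -0.5, 'win': 10.0, 'lose': -5.0}

def episode_return(n_correct, n_wrong, won, n_repeated=0, rewards=REWARDS):
    """Sum per-guess rewards plus the terminal bonus (win) or penalty (lose)."""
    total = (n_correct * rewards['correct']
             + n_wrong * rewards['wrong']
             + n_repeated * rewards['repeated'])
    total += rewards['win'] if won else rewards['lose']
    return total

# A win with 5 correct and 2 wrong guesses: 5*0.5 - 2*1.0 + 10.0
print(episode_return(n_correct=5, n_wrong=2, won=True))
# A loss with 3 correct and 6 wrong guesses: 3*0.5 - 6*1.0 - 5.0
print(episode_return(n_correct=3, n_wrong=6, won=False))
```

Under these values a single wrong guess cancels two correct ones, so most of the learning signal comes from the win bonus and the wrong-guess penalty.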


## 2.1 HMM Training - Theory and Implementation

**Training Objective:**
Maximize likelihood: $P(\text{word}) = P(l_1) \prod_{i=2}^{n} P(l_i \mid l_{i-1})$

**Components:**
- **Hidden States**: Letter positions/clusters
- **Transitions**: A[i][j] = P(letter_j | letter_i)
- **Emissions**: B[j][k] = P(obs_k | state_j)

**Key Features:**
- ✅ Additive smoothing (α ≈ 0.01-0.1) to prevent overfitting
- ✅ 1st-order HMM: letter depends on previous letter
- ✅ Position-based fallback for different word lengths
- ✅ Bigram pattern matching for common structures
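
As a minimal sketch of the additive smoothing listed above, the following estimates P(curr | prev) from bigram counts so that every row is a proper distribution. The function `smoothed_transitions` and the three-word corpus are assumptions for illustration. `FirstOrderHMM.train` may organize its counts differently.

```python
from collections import Counter, defaultdict
import string

ALPHABET = string.ascii_lowercase

def smoothed_transitions(words, alpha=0.01):
    """Estimate P(curr | prev) with additive smoothing over a 26-letter alphabet."""
    counts = defaultdict(Counter)
    for word in words:
        for prev, curr in zip(word, word[1:]):
            counts[prev][curr] += 1
    probs = {}
    for prev in ALPHABET:
        # Denominator: observed transitions out of prev plus alpha for each possible letter
        total = sum(counts[prev].values()) + alpha * len(ALPHABET)
        probs[prev] = {c: (counts[prev][c] + alpha) / total for c in ALPHABET}
    return probs

toy_corpus = ["apple", "ample", "maple"]  # hypothetical mini corpus
probs = smoothed_transitions(toy_corpus, alpha=0.01)
print(f"P(p | a) = {probs['a']['p']:.3f}")
print(f"P(z | a) = {probs['a']['z']:.5f}")
```

Unseen transitions such as a→z get a small nonzero probability, which keeps the log-likelihood finite on validation words.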

**Validation:**
- Monitor training vs validation perplexity
- Check for overfitting (validation >> training)
- Use held-out validation set (10%)
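
A useful reference point for the perplexity check is the uniform model, whose perplexity over a 26-letter alphabet is exactly 26. A trained model should score well below that. The sketch below computes this baseline and a validation/training ratio. The helper names and the perplexities passed to `overfit_ratio` are hypothetical.

```python
import math

def perplexity(words, log_prob_fn):
    """exp(-total log-likelihood / total letters)."""
    total_ll = sum(log_prob_fn(w) for w in words)
    total_letters = sum(len(w) for w in words)
    return math.exp(-total_ll / total_letters) if total_letters else float('inf')

def uniform_log_prob(word):
    # Hypothetical baseline: every letter has probability 1/26
    return len(word) * math.log(1 / 26)

def overfit_ratio(train_ppl, val_ppl):
    """A ratio above 1 means validation words surprise the model more than training words."""
    return val_ppl / train_ppl

baseline = perplexity(["hangman", "apple"], uniform_log_prob)
print(f"Uniform baseline perplexity: {baseline:.2f}")
print(f"Overfit ratio: {overfit_ratio(12.0, 19.5):.2f}")  # made-up perplexities
```

The evaluation cell in this notebook flags overfitting when the ratio exceeds 1.5.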


## 4.1 Deep Q-Network (DQN) Agent

**DQN Advantages:**
- Uses neural network to approximate Q-values (handles large state spaces)
- Experience replay for better sample efficiency
- Target network for stable training
- Can learn complex patterns better than table-based Q-learning


### Note: DQN fully removed. Using optimized Q-learning only.


In [None]:
# Train HMM on training corpus (NOT validation/test)
hmm = FirstOrderHMM(smoothing=0.01)
hmm.train(training_corpus)
print("\nHMM training complete!")

# Evaluate HMM on training and validation sets (check for overfitting)
print("\nEvaluating HMM...")
train_perplexity = hmm.calculate_perplexity(training_corpus[:1000])  # Sample for speed
val_perplexity = hmm.calculate_perplexity(validation_corpus[:100])   # Sample for speed

print(f"Training Perplexity (sample): {train_perplexity:.2f}")
print(f"Validation Perplexity (sample): {val_perplexity:.2f}")
print(f"Difference: {abs(train_perplexity - val_perplexity):.2f}")

# Check for overfitting: if validation perplexity >> training, model is overfitting
if val_perplexity > train_perplexity * 1.5:
    print("⚠️  WARNING: Possible overfitting detected!")
    print("   Consider: reducing model complexity, increasing smoothing, or early stopping")
elif abs(train_perplexity - val_perplexity) < train_perplexity * 0.1:
    print("✅ HMM generalization looks good!")
else:
    print("ℹ️  HMM performance: training and validation perplexity differ by 10% or more, below the overfitting threshold")


## 2.2 HMM Hyperparameters and Tuning


In [None]:
# HMM Hyperparameters (tunable)
HMM_SMOOTHING = 0.01  # Additive smoothing parameter (α)
# Typical range: 0.01-0.1
# Lower = more confident, but risk overfitting
# Higher = more conservative, but may underfit

print(f"HMM Configuration:")
print(f"  Smoothing (α): {HMM_SMOOTHING}")
print(f"  Model order: 1st-order (letter depends on previous letter)")
print(f"  Validation split: 10%")
print(f"\nSmoothing effect:")
print(f"  P(letter|prev) = (count + {HMM_SMOOTHING}) / (total + {HMM_SMOOTHING} * 26)")

# Tuning guidance
print(f"\n💡 Tuning Tips:")
print(f"  - If validation perplexity >> training: increase smoothing (reduce overfitting)")
print(f"  - If perplexity too uniform: decrease smoothing (increase capacity)")
print(f"  - Good balance: training ≈ validation perplexity")
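
The smoothing formula printed above can be evaluated directly to see what changing α does. This sketch uses made-up counts (40 of 100 transitions for a seen letter, 0 for an unseen one). `smoothed_prob` is a hypothetical helper.

```python
def smoothed_prob(count, total, alpha, vocab_size=26):
    """(count + alpha) / (total + alpha * vocab_size)"""
    return (count + alpha) / (total + alpha * vocab_size)

for alpha in (0.01, 0.1, 1.0):
    seen = smoothed_prob(count=40, total=100, alpha=alpha)
    unseen = smoothed_prob(count=0, total=100, alpha=alpha)
    print(f"alpha={alpha:<4}  P(seen)={seen:.4f}  P(unseen)={unseen:.6f}")
```

Raising α moves probability mass from observed transitions to unobserved ones, which is why a larger α counters overfitting and a smaller α sharpens the distribution.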


In [None]:
# Test HMM on example masked words
# (masks are lowercased before scoring because the model uses a lowercase alphabet)
test_cases = [
    "_PPLE",      # Should suggest 'A'
    "_PP_E",      # Should suggest 'A'
    "____",       # 4-letter word
    "_HAT",       # Should suggest 'C' or 'T'
    "H_NG_AN",    # Blanks are 'A' and 'M' (HANGMAN); contrived, since real play reveals every 'A' at once
]

print("Testing HMM predictions:")
for masked in test_cases:
    probs = hmm.get_probabilities_for_mask(masked.lower())
    top_5 = sorted(probs.items(), key=lambda x: x[1], reverse=True)[:5]
    print(f"\nMasked word: {masked}")
    print("Top 5 letter predictions:")
    for letter, prob in top_5:
        print(f"  {letter}: {prob:.4f}")

# Visualize transition matrix heatmap
print("\n\nVisualizing transition matrix (most common transitions)...")
letters = list('abcdefghijklmnopqrstuvwxyz')
transition_matrix = np.zeros((26, 26))

for i, prev in enumerate(letters):
    for j, curr in enumerate(letters):
        transition_matrix[i, j] = hmm.get_letter_probability_given_prev(prev, curr)

plt.figure(figsize=(12, 10))
sns.heatmap(transition_matrix, xticklabels=letters, yticklabels=letters,
            cmap='YlOrRd', cbar_kws={'label': 'Transition Probability'},
            fmt='.3f', annot=False)
plt.xlabel('Next Letter')
plt.ylabel('Previous Letter')
plt.title('HMM Transition Matrix: P(Next Letter | Previous Letter)')
plt.tight_layout()
plt.show()

print("\nNote: Diagonal cells correspond to double letters (e.g., 'll', 'ee')")
print("High values show common letter bigrams in the corpus")


## 4. Q-Learning Agent


In [None]:
from functools import partial

class QLearningAgent:
    """Q-Learning agent for Hangman"""
    
    def __init__(self, learning_rate=0.1, discount_factor=0.95, 
                 epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01,
                 hmm_weight=1.0, q_weight=1.0):
        # Learning parameters
        self.learning_rate = learning_rate        # α
        self.discount_factor = discount_factor    # γ
        
        # Exploration parameters
        self.epsilon = epsilon                    # Initial exploration
        self.epsilon_decay = epsilon_decay       # Multiplicative decay factor
        self.epsilon_min = epsilon_min           # Minimum exploration
        
        # Hybrid weights for combining Q and HMM.
        # The 1.0 defaults are placeholders (assumption); pass tuned values via RL_CONFIG.
        self.hmm_weight = hmm_weight
        self.q_weight = q_weight

        # Q-table: state → action → Q-value
        # partial(...) instead of a lambda keeps the table picklable for save()
        self.Q = defaultdict(partial(defaultdict, float))
        
        # Training statistics
        self.training_stats = {
            'episodes': [],
            'rewards': [],
            'wins': [],
            'losses': []
        }
    
    def state_to_key(self, state, hmm_probs):
        """Convert state to string key for Q-table"""
        masked = state['masked_word']
        guessed = ''.join(sorted(state['guessed_letters']))
        lives = state['lives_left']
        word_len = state['word_length']
        
        # Include word length and lives for better state representation
        return f"{masked}:{word_len}:{guessed}:{lives}"
    
    def get_available_actions(self, state):
        """Get list of available letters to guess"""
        all_letters = set('abcdefghijklmnopqrstuvwxyz')
        return sorted(all_letters - state['guessed_letters'])
    
    def select_action(self, state, hmm_probs):
        """Select action using ε-greedy policy"""
        available_actions = self.get_available_actions(state)
        
        if not available_actions:
            return None
        
        # Low-lives safeguard: switch to HMM-greedy to minimize wrong guesses
        if state.get('lives_left', 6) <= 2:
            return max(available_actions, key=lambda a: hmm_probs.get(a, 0.0))

        # Exploration: use HMM probabilities (smart exploration)
        if random.random() < self.epsilon:
            # Information-gain driven exploration (with HMM guidance)
            ig_map = info_gain_for_state(state)
            cand = []
            for a in available_actions:
                cand.append((a, 0.6*ig_map.get(a,0.0) + 0.4*hmm_probs.get(a,0.0)))
            cand.sort(key=lambda x: x[1], reverse=True)
            top_k = [a for a,_ in cand[:8]]
            return random.choice(top_k if top_k else available_actions)
        
        # Exploitation: choose best action based on Q-values + HMM
        state_key = self.state_to_key(state, hmm_probs)
        
        best_action = None
        best_value = float('-inf')
        
        for action in available_actions:
            q_value = self.Q[state_key][action]
            # Combine Q-value with HMM probability using the configured weights
            combined = self.q_weight * q_value + self.hmm_weight * hmm_probs.get(action, 0)
            
            if combined > best_value:
                best_value = combined
                best_action = action
        
        if best_action is None:
            best_action = random.choice(available_actions)
        
        return best_action
    
    def update(self, state, action, reward, next_state, hmm_probs, done):
        """Update Q-value using Q-learning - IMPROVED with HMM initialization"""
        state_key = self.state_to_key(state, hmm_probs)
        next_state_key = self.state_to_key(next_state, hmm_probs)
        
        current_q = self.Q[state_key][action]
        
        # IMPROVEMENT: Initialize Q-value using HMM if never seen
        if current_q == 0 and action in hmm_probs:
            hmm_prob = hmm_probs.get(action, 0)
            current_q = hmm_prob * 2  # Optimistic initialization
        
        if done:
            max_next_q = 0
        else:
            available_actions = self.get_available_actions(next_state)
            if available_actions:
                max_next_q = max([self.Q[next_state_key][a] for a in available_actions], default=0)
                # Initialize next state using HMM if no Q-values
                if max_next_q == 0:
                    next_hmm_values = [hmm_probs.get(a, 0) * 2 for a in available_actions]
                    max_next_q = max(next_hmm_values) if next_hmm_values else 0
            else:
                max_next_q = 0
        
        # Q-learning update with adaptive learning rate
        adaptive_lr = self.learning_rate * (2.0 if current_q == 0 else 1.0)
        new_q = current_q + adaptive_lr * (reward + self.discount_factor * max_next_q - current_q)
        self.Q[state_key][action] = new_q
        
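        # Decay epsilon once per finished episode (update() runs on every step,
        # so an unconditional decay here would exhaust exploration within a few dozen games)
        if done and self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay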
    
    def save(self, filepath):
        """Save agent"""
        with open(filepath, 'wb') as f:
            pickle.dump(self, f)
        print(f"Agent saved to {filepath}")
    
    @staticmethod
    def load(filepath):
        """Load agent"""
        with open(filepath, 'rb') as f:
            return pickle.load(f)

print("Q-Learning Agent class defined!")







## 5.1 STAGE 3: Hybrid HMM + RL Training

**Integration Strategy:**
- Combine HMM probabilities with RL state representation
- Use HMM to guide exploration (smart exploration)
- RL learns optimal policy given HMM information
- Add noise to HMM occasionally to prevent over-reliance
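
The integration strategy above comes down to scoring each candidate letter with a weighted sum of its Q-value and its HMM probability. A minimal sketch follows. The function `hybrid_best_action`, the weights and the example values are assumptions for illustration.

```python
def hybrid_best_action(q_values, hmm_probs, available, q_weight=1.0, hmm_weight=1.0):
    """Pick the letter maximizing q_weight * Q + hmm_weight * P_hmm."""
    def score(letter):
        return q_weight * q_values.get(letter, 0.0) + hmm_weight * hmm_probs.get(letter, 0.0)
    return max(available, key=score)

# Hypothetical values for one state
q_values = {'e': 0.10, 'a': 0.30}
hmm_probs = {'e': 0.40, 'a': 0.15, 's': 0.20}
available = ['a', 'e', 's']

# Equal weights: the HMM's preference for 'e' dominates
print(hybrid_best_action(q_values, hmm_probs, available))
# Heavier Q weight: the learned value of 'a' dominates
print(hybrid_best_action(q_values, hmm_probs, available, q_weight=5.0))
```

Because Q-values start at zero for unseen states, the HMM term decides early on, and the Q term gains influence as states are revisited.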

**Overfitting Prevention:**
- ✅ Periodic word shuffling (prevent memorization)
- ✅ Noise injection in HMM probabilities (10% of episodes)
- ✅ Validation monitoring during training
- ✅ Train/test separation (no data leakage)
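
The noise injection can be sketched as a small helper that perturbs each probability proportionally, clips at zero and renormalizes. This version uses the standard library's `random.gauss` rather than NumPy. The helper name `add_noise` and the example distribution are assumptions.

```python
import random

def add_noise(probs, rel_sigma=0.1, rng=random):
    """Add Gaussian noise with sigma = rel_sigma * p to each entry, clip at 0, renormalize."""
    noisy = {k: max(0.0, p + rng.gauss(0.0, rel_sigma * p)) for k, p in probs.items()}
    total = sum(noisy.values())
    # If everything was clipped to zero, return the input unchanged
    return {k: v / total for k, v in noisy.items()} if total > 0 else dict(probs)

rng = random.Random(0)  # fixed seed so the sketch is reproducible
hmm_probs = {'e': 0.5, 'a': 0.3, 's': 0.2, 'q': 0.0}
noisy = add_noise(hmm_probs, rng=rng)
print({k: round(v, 3) for k, v in noisy.items()})
```

Scaling the noise by each probability keeps zero-probability letters (already guessed ones, for example) at zero.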


In [None]:
# Initialize agent
agent = QLearningAgent(**RL_CONFIG)

print("Agent initialized!")
print(f"Initial epsilon: {agent.epsilon}")



## 5. Training Loop - Online Learning

**Key Features:**
- ✅ **HMM probabilities recalculated after EVERY guess** - As masked_word changes, probabilities update
- ✅ **Q-values updated IMMEDIATELY after each action** - Online learning (not batch learning)
- ✅ **Step-by-step learning** - Agent learns from each guess, not just at episode end

**Learning Flow:**
1. Get HMM probabilities for current state (masked_word, guessed_letters)
2. Agent selects action (letter to guess)
3. Execute action → get reward, new state
4. Recalculate HMM probabilities for NEW state (probabilities change after reveal!)
5. Update Q-values IMMEDIATELY using (state, action, reward, next_state)
6. Move to next iteration

This is **online/temporal-difference learning** - learns continuously, not in batches!
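
Step 5 of the flow is the standard tabular Q-learning update, Q(s, a) ← Q(s, a) + α · (r + γ · max Q(s', a') − Q(s, a)). The sketch below works one step through with the agent's default α = 0.1 and γ = 0.95. The function `q_update` and the example numbers are illustrative, and the HMM-based initialization used in `QLearningAgent.update` is left out.

```python
def q_update(q_sa, reward, max_next_q, alpha=0.1, gamma=0.95, done=False):
    """One tabular Q-learning step: Q <- Q + alpha * (r + gamma * max Q' - Q)."""
    # Terminal transitions have no future value to bootstrap from
    target = reward if done else reward + gamma * max_next_q
    return q_sa + alpha * (target - q_sa)

# Hypothetical step: correct guess (+0.5), best next-state value 2.0, current Q = 1.0
# target = 0.5 + 0.95 * 2.0 = 2.4, so Q moves 10% of the way from 1.0 toward 2.4
new_q = q_update(q_sa=1.0, reward=0.5, max_next_q=2.0)
print(round(new_q, 4))
```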


In [None]:
# Training parameters
NUM_EPISODES = 5000  # Further increased for better performance (was 3000)
TRAINING_SUBSET = 10000  # More diverse training words (was 8000)

# IMPORTANT: Use only TRAINING corpus (not validation or test!)
# This prevents data leakage

# Strategy 1: Sample proportionally from each bucket (maintains length distribution)
def sample_from_buckets(buckets, total_samples):
    """Sample words proportionally from each bucket"""
    # Calculate total words
    total_words = sum(len(words) for words in buckets.values())
    
    # Sample proportionally from each bucket
    sampled_words = []
    for length, words in buckets.items():
        proportion = len(words) / total_words
        samples_from_bucket = max(1, int(total_samples * proportion))
        sampled_words.extend(random.sample(words, min(samples_from_bucket, len(words))))
    
    return sampled_words

# Strategy 2: Sample uniformly across lengths (equal representation per length)
def sample_uniform_from_buckets(buckets, samples_per_length):
    """Sample equal number of words from each length bucket"""
    sampled_words = []
    for length, words in buckets.items():
        if len(words) > 0:
            samples = min(samples_per_length, len(words))
            sampled_words.extend(random.sample(words, samples))
    return sampled_words

# Choose sampling strategy
# Strategy 1: Proportional (maintains natural distribution)
rl_training_words = sample_from_buckets(training_buckets, min(TRAINING_SUBSET, len(training_corpus)))
if not rl_training_words:
    # Fallback: use training_corpus directly (shuffled)
    rl_training_words = training_corpus.copy()
    random.shuffle(rl_training_words)
# Final guard: never fall back to test_words, since that would leak test data into training
if not rl_training_words:
    raise ValueError("No training words available; check training_corpus and training_buckets")


# Strategy 2: Uniform (uncomment to use instead)
# SAMPLES_PER_LENGTH = TRAINING_SUBSET // len(training_buckets)  # Distribute evenly
# rl_training_words = sample_uniform_from_buckets(training_buckets, SAMPLES_PER_LENGTH)

print(f"RL Training Configuration:")
print(f"  Training words: {len(rl_training_words)}")
print(f"  Number of episodes: {NUM_EPISODES}")
print(f"  Validation set: {len(validation_corpus)} words (for monitoring)")
print(f"  Test set: {len(test_words)} words (FINAL evaluation only)")

# Show distribution of sampled training words
training_sample_buckets = bucket_words_by_length(rl_training_words)
print(f"\nSampled training words distribution:")
for length in sorted(training_sample_buckets.keys()):
    print(f"  Length {length:2d}: {len(training_sample_buckets[length]):4d} words")

# Prepare validation subset for RL evaluation during training
val_subset = sample_from_buckets(validation_buckets, min(500, len(validation_corpus)))
if not val_subset:
    val_subset = validation_corpus[:500]
print(f"\n  Validation subset for RL monitoring: {len(val_subset)} words")



In [None]:
# Training loop with validation monitoring (to detect overfitting)
episode_rewards = []
episode_wins = []
episode_losses = []
episode_wrong_guesses = []
episode_step_counts = []  # Track steps per episode

# Validation metrics (to track overfitting)
validation_win_rates = []
validation_episodes = []

# Per-step learning statistics (optional - for analysis)
step_rewards = []  # Reward at each step (across all episodes)
step_q_updates = 0  # Count total Q-value updates

print("="*70)
print("RL TRAINING - ONLINE LEARNING (Updates after EVERY guess)")
print("="*70)
print("Key features:")
print("  ✓ HMM probabilities recalculated after each guess")
print("  ✓ Q-values updated immediately after each action")
print("  ✓ Learning happens step-by-step, not episode-by-episode")
print(f"{'='*70}")
print(f"\nStarting training on {NUM_EPISODES} episodes...")
print(f"Progress every {NUM_EPISODES // 10} episodes:")

for episode in range(NUM_EPISODES):
    # Re-shuffle and re-bucket the training words every 100 episodes (episode 0 included).
    # Note: random.choice below is order-independent, so the shuffle does not change which words are drawn.
    if episode % 100 == 0:
        random.shuffle(rl_training_words)
        training_sample_buckets = bucket_words_by_length(rl_training_words)
    
    # Pick the episode word from the TRAINING set only.
    # available_lengths is only needed for Option A below.
    available_lengths = [l for l in training_sample_buckets.keys() if len(training_sample_buckets[l]) > 0]
    if len(available_lengths) == 0:
        available_lengths = list(training_buckets.keys())
    
    # Option A: Uniform across lengths (uncomment to use)
    # chosen_length = random.choice(available_lengths)
    # word = random.choice(training_sample_buckets[chosen_length])
    
    # Option B (default): uniform over words, which keeps the natural length distribution
    word = random.choice(rl_training_words) if rl_training_words else random.choice(training_corpus)
    env = HangmanEnv(word, max_lives=6, max_guesses=25)
    state = env.get_state()
    
    episode_reward = 0
    episode_wrong = 0
    done = False
    info = {'status': None}  # Defined up front in case the loop exits before any guess
    
    # Add noise to HMM probabilities occasionally (prevent over-reliance)
    use_noisy_hmm = random.random() < 0.1  # 10% of time, add noise
    
    step_count = 0
    
    while not done:
        step_count += 1
        
        # ============================================
        # STEP 1: Get HMM probabilities for CURRENT state
        # ============================================
        # IMPORTANT: Recalculate probabilities after each guess!
        # The masked_word and guessed_letters change after each guess
        hmm_probs = get_letter_probs(hmm, state)
        
        # Optionally add noise to prevent overfitting to HMM quirks
        if use_noisy_hmm:
            # Add small Gaussian noise
            noisy_probs = {}
            for letter, prob in hmm_probs.items():
                noise = np.random.normal(0, prob * 0.1)  # 10% noise
                noisy_probs[letter] = max(0, prob + noise)
            # Renormalize
            total = sum(noisy_probs.values())
            if total > 0:
                hmm_probs = {k: v/total for k, v in noisy_probs.items()}
        
        # ============================================
        # STEP 2: Agent selects action based on current state + HMM probs
        # ============================================
        action = agent.select_action(state, hmm_probs)
        
        if action is None:
            break
        
        # ============================================
        # STEP 3: Execute action in environment
        # ============================================
        reward, next_state, done, info = env.guess_letter(action)
        
        if info['status'] == 'wrong':
            episode_wrong += 1
        
        # ============================================
        # STEP 4: Recalculate HMM probabilities for NEW state
        # ============================================
        # CRITICAL: Get updated probabilities after the guess!
        # The masked_word has changed (letters revealed), so probabilities change
        next_hmm_probs = get_letter_probs(hmm, next_state)
        
        # ============================================
        # STEP 5: UPDATE Q-VALUES IMMEDIATELY (after each guess)
        # ============================================
        # Learn from this experience RIGHT AWAY - don't wait for episode end!
        # This is online learning: update after every single action
        agent.update(
            state=state,           # Current state
            action=action,          # Action taken
            reward=reward,          # Reward received
            next_state=next_state,  # New state after guess
            hmm_probs=next_hmm_probs,  # Updated HMM probs for new state
            done=done               # Whether episode ended
        )
        
        # ============================================
        # STEP 6: Move to next state for next iteration
        # ============================================
        state = next_state
        episode_reward += reward
        
        # Optional: Track per-step learning (for debugging/analysis)
        # You can log Q-value changes here if needed
    
    # Track statistics
    episode_rewards.append(episode_reward)
    if info['status'] == 'won':
        episode_wins.append(1)
        episode_losses.append(0)
    elif info['status'] == 'lost':
        episode_wins.append(0)
        episode_losses.append(1)
    else:
        episode_wins.append(0)
        episode_losses.append(0)
    
    episode_wrong_guesses.append(episode_wrong)
    episode_step_counts.append(step_count)  # Track steps in this episode
    
    # Log learning progress
    step_q_updates += step_count  # Each step = one Q-value update
    
    # Periodic validation check (to detect overfitting)
    if (episode + 1) % (NUM_EPISODES // 5) == 0:
        # Evaluate on validation set (unseen words)
        val_wins = 0
        for val_word in val_subset[:100]:  # Sample 100 validation words
            env_val = HangmanEnv(val_word, max_lives=6, max_guesses=25)
            state_val = env_val.get_state()
            done_val = False
            
            # Use greedy policy (no exploration) for evaluation
            old_epsilon = agent.epsilon
            agent.epsilon = 0
            
            while not done_val:
                hmm_probs_val = hmm.get_probabilities_for_mask(state_val['masked_word'], state_val['guessed_letters'])
                action_val = agent.select_action(state_val, hmm_probs_val)
                if action_val is None:
                    break
                _, state_val, done_val, info_val = env_val.guess_letter(action_val)
                if done_val and info_val['status'] == 'won':
                    val_wins += 1
            
            agent.epsilon = old_epsilon
        
        val_win_rate = val_wins / min(100, len(val_subset))
        validation_win_rates.append(val_win_rate)
        validation_episodes.append(episode + 1)
    
    # Progress update
    if (episode + 1) % (NUM_EPISODES // 10) == 0:
        recent_win_rate = np.mean(episode_wins[-100:]) if len(episode_wins) >= 100 else np.mean(episode_wins)
        recent_avg_reward = np.mean(episode_rewards[-100:]) if len(episode_rewards) >= 100 else np.mean(episode_rewards)
        print(f"Episode {episode + 1}/{NUM_EPISODES} | "
              f"Train Win Rate: {recent_win_rate:.2%} | "
              f"Avg Reward: {recent_avg_reward:.2f} | "
              f"Epsilon: {agent.epsilon:.3f}")

print("\n" + "="*70)
print("TRAINING COMPLETE - ONLINE LEARNING SUMMARY")
print("="*70)
print(f"Total episodes: {NUM_EPISODES}")
print(f"Total Q-value updates (learning steps): {step_q_updates}")
print(f"Average steps per episode: {np.mean(episode_step_counts):.2f}")
print(f"Final epsilon: {agent.epsilon:.3f}")
print(f"Total states in Q-table: {len(agent.Q)}")
print(f"Average Q-table size growth: {len(agent.Q) / NUM_EPISODES:.2f} states per episode")
print(f"\n✓ Learning happened after EVERY guess (online learning)")
print(f"✓ HMM probabilities recalculated after each guess")
print(f"✓ Total learning experiences: {step_q_updates}")
print("="*70)

# Check for overfitting: if training win rate >> validation win rate
if len(validation_win_rates) > 0:
    final_train_win_rate = np.mean(episode_wins[-100:]) if len(episode_wins) >= 100 else np.mean(episode_wins)
    final_val_win_rate = validation_win_rates[-1]
    print(f"\nOverfitting Check:")
    print(f"  Final Training Win Rate: {final_train_win_rate:.2%}")
    print(f"  Final Validation Win Rate: {final_val_win_rate:.2%}")
    if final_train_win_rate > final_val_win_rate * 1.5:
        print("  ⚠️  WARNING: Possible overfitting detected!")
    elif abs(final_train_win_rate - final_val_win_rate) < final_train_win_rate * 0.15:
        print("  ✅ Good generalization!")




## 6. Training Visualizations


In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Reward over episodes
axes[0, 0].plot(episode_rewards, alpha=0.3, label='Episode Reward')
# Moving average
window = 50
if len(episode_rewards) >= window:
    moving_avg = np.convolve(episode_rewards, np.ones(window)/window, mode='valid')
    axes[0, 0].plot(range(window-1, len(episode_rewards)), moving_avg, 'r-', label=f'{window}-Episode Average')
axes[0, 0].set_xlabel('Episode')
axes[0, 0].set_ylabel('Reward')
axes[0, 0].set_title('Reward Over Episodes')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Win rate over episodes
# Cumulative win rate after each episode (running mean, computed in one pass)
win_rates = np.cumsum(episode_wins) / np.arange(1, len(episode_wins) + 1)
axes[0, 1].plot(win_rates, label='Cumulative Win Rate')
axes[0, 1].set_xlabel('Episode')
axes[0, 1].set_ylabel('Win Rate')
axes[0, 1].set_title('Win Rate Over Episodes')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Wrong guesses over episodes
axes[1, 0].plot(episode_wrong_guesses, alpha=0.3, label='Wrong Guesses')
if len(episode_wrong_guesses) >= window:
    moving_avg_wrong = np.convolve(episode_wrong_guesses, np.ones(window)/window, mode='valid')
    axes[1, 0].plot(range(window-1, len(episode_wrong_guesses)), moving_avg_wrong, 'r-', label=f'{window}-Episode Average')
axes[1, 0].set_xlabel('Episode')
axes[1, 0].set_ylabel('Wrong Guesses')
axes[1, 0].set_title('Wrong Guesses Over Episodes')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

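# Epsilon decay (schedule reconstructed here, not read from the agent: the start of 1.0,
# decay factor of 0.995 and floor of 0.01 must match the values used during training)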
epsilon_values = []
epsilon = 1.0
for _ in range(NUM_EPISODES):
    epsilon_values.append(epsilon)
    if epsilon > 0.01:
        epsilon *= 0.995
axes[1, 1].plot(epsilon_values, label='Epsilon')
axes[1, 1].set_xlabel('Episode')
axes[1, 1].set_ylabel('Epsilon')
axes[1, 1].set_title('Exploration (Epsilon) Over Episodes')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print summary statistics
print(f"\nTraining Summary:")
print(f"Total Episodes: {NUM_EPISODES}")
print(f"Overall Win Rate: {np.mean(episode_wins):.2%}")
print(f"Average Reward: {np.mean(episode_rewards):.2f}")
print(f"Average Wrong Guesses: {np.mean(episode_wrong_guesses):.2f}")
print(f"Recent 100 Episodes Win Rate: {np.mean(episode_wins[-100:]):.2%}")
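
The smoothed curves use a `valid`-mode moving average, which yields `len(values) - window + 1` points; that is why they are plotted starting at x = `window - 1`. A dependency-free sketch of the same computation follows. The helper name `moving_average` is illustrative, and it is written in plain Python rather than with `np.convolve` so the arithmetic is explicit.

```python
def moving_average(values, window):
    """Average of each full window; element k averages values[k : k + window]."""
    if window <= 0:
        raise ValueError("window must be positive")
    if len(values) < window:
        return []  # no full window fits, matching the length guard in the plotting code
    running = sum(values[:window])
    averages = [running / window]
    for i in range(window, len(values)):
        # Slide the window: add the new value, drop the oldest one
        running += values[i] - values[i - window]
        averages.append(running / window)
    return averages


rewards = [1, 2, 3, 4, 5, 6]
window = 3
smoothed = moving_average(rewards, window)
x_positions = list(range(window - 1, len(rewards)))
print(smoothed)      # [2.0, 3.0, 4.0, 5.0]
print(x_positions)   # [2, 3, 4, 5]
```

Each smoothed point is aligned with the last episode of its window, so the red curve lags the raw data by up to `window - 1` episodes.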


## 7. Evaluation on Test Set


## 10. Training Health Metrics Summary

**Quantitative Indicators of Healthy Training:**


In [None]:
# Comprehensive Training Health Report
print("\n" + "="*70)
print("COMPREHENSIVE TRAINING HEALTH REPORT")
print("="*70)

print("\n1. HMM Training Health:")
if 'train_perplexity' in locals() and 'val_perplexity' in locals():
    if abs(train_perplexity - val_perplexity) < train_perplexity * 0.1:
        print(f"   ✅ HMM: Training ≈ Validation perplexity (Good generalization)")
    elif val_perplexity > train_perplexity * 1.5:
        print(f"   ⚠️  HMM: Overfitting detected (val >> train)")
    else:
        print(f"   ✓ HMM: Acceptable generalization gap")
else:
    print(f"   ⚠️  HMM: Perplexity not calculated - run HMM validation cell")

print("\n2. RL Training Health:")
if len(episode_rewards) > 0:
    # Check reward curve trend
    early_reward = np.mean(episode_rewards[:100]) if len(episode_rewards) >= 100 else np.mean(episode_rewards[:len(episode_rewards)//4])
    late_reward = np.mean(episode_rewards[-100:]) if len(episode_rewards) >= 100 else np.mean(episode_rewards[-len(episode_rewards)//4:])
    
    # Require a 20% gain relative to |early_reward| so the check also holds for negative rewards
    if late_reward > early_reward + 0.2 * abs(early_reward):
        print(f"   ✅ RL: Reward curve rising (learning effectively)")
    elif late_reward > early_reward:
        print(f"   ✓ RL: Reward curve improving (slow but learning)")
    else:
        print(f"   ⚠️  RL: Reward not improving (may need tuning)")
    
    # Check win rate trend
    early_win = np.mean(episode_wins[:100]) if len(episode_wins) >= 100 else np.mean(episode_wins[:len(episode_wins)//4])
    late_win = np.mean(episode_wins[-100:]) if len(episode_wins) >= 100 else np.mean(episode_wins[-len(episode_wins)//4:])
    
    if late_win > early_win * 1.5:
        print(f"   ✅ RL: Win rate improving significantly ({early_win:.1%} → {late_win:.1%})")
    elif late_win > early_win:
        print(f"   ✓ RL: Win rate improving ({early_win:.1%} → {late_win:.1%})")
    else:
        print(f"   ⚠️  RL: Win rate stagnant ({late_win:.1%})")

print("\n3. Validation vs Training:")
if len(validation_win_rates) > 0:
    final_train = np.mean(episode_wins[-100:]) if len(episode_wins) >= 100 else np.mean(episode_wins)
    final_val = validation_win_rates[-1]
    gap = abs(final_train - final_val) / final_train if final_train > 0 else 0
    
    if gap < 0.15:
        print(f"   ✅ Generalization: Training ≈ Validation (gap < 15%)")
    elif gap < 0.30:
        print(f"   ✓ Generalization: Moderate gap ({gap*100:.1f}%)")
    else:
        print(f"   ⚠️  Generalization: Large gap ({gap*100:.1f}%) - possible overfitting")

print("\n4. Wrong Guesses Trend:")
if len(episode_wrong_guesses) > 100:
    early_wrong = np.mean(episode_wrong_guesses[:100])
    late_wrong = np.mean(episode_wrong_guesses[-100:])
    if late_wrong < early_wrong * 0.8:
        print(f"   ✅ Wrong guesses decreasing ({early_wrong:.2f} → {late_wrong:.2f})")
    elif late_wrong < early_wrong:
        print(f"   ✓ Wrong guesses improving ({early_wrong:.2f} → {late_wrong:.2f})")
    else:
        print(f"   ⚠️  Wrong guesses not improving ({late_wrong:.2f})")

print("\n5. Repeated Guesses:")
if 'total_repeated' in locals():
    if total_repeated == 0:
        print(f"   ✅ Zero repeated guesses (perfect!)")
    elif total_repeated < len(test_words) * 0.05:
        print(f"   ✓ Low repeated guesses ({total_repeated}, < 5% of games)")
    else:
        print(f"   ⚠️  Too many repeated guesses ({total_repeated})")

print("\n6. Overall Assessment:")
all_good = True
if 'train_perplexity' in locals() and 'val_perplexity' in locals() and val_perplexity > train_perplexity * 1.5:
    all_good = False
if len(validation_win_rates) > 0:
    gap = abs(np.mean(episode_wins[-100:]) - validation_win_rates[-1])
    if gap > np.mean(episode_wins[-100:]) * 0.30:
        all_good = False

if 'success_rate' not in locals():
    # success_rate is computed by the test-set evaluation cell
    print(f"   ⚠️  Overall: success_rate not available - run the evaluation cell first")
elif all_good and success_rate >= 0.6:
    print(f"   ✅ Overall: Healthy training, good performance!")
elif success_rate >= 0.4:
    print(f"   ✓ Overall: Acceptable performance, may benefit from more training")
else:
    print(f"   ⚠️  Overall: Needs improvement - consider tuning hyperparameters")

print("="*70)


In [None]:
# Evaluation function
def evaluate_agent(agent, hmm, test_words, max_lives=6):
    """Evaluate agent on test set"""
    results = {
        'wins': 0,
        'losses': 0,
        'total_wrong_guesses': 0,
        'total_repeated_guesses': 0,
        'total_guesses': 0,
        'game_results': []
    }
    
    print(f"Evaluating on {len(test_words)} test words...")
    
    for i, word in enumerate(test_words):
        env = HangmanEnv(word, max_lives=max_lives, max_guesses=25)
        state = env.get_state()
        done = False
        wrong_guesses = 0
        repeated_guesses = 0
        info = {'status': None}  # Defined up front in case the loop exits before any guess
        
        while not done:
            # Get HMM probabilities
            hmm_probs = hmm.get_probabilities_for_mask(state['masked_word'], state['guessed_letters'])
            
            # Agent selects action (use greedy policy - no exploration):
            # temporarily set epsilon to 0 for evaluation, then restore it
            old_epsilon = agent.epsilon
            agent.epsilon = 0
            try:
                action = agent.select_action(state, hmm_probs)
            finally:
                agent.epsilon = old_epsilon
            
            if action is None:
                break
            
            # Execute action
            reward, next_state, done, info = env.guess_letter(action)
            
            if info['status'] == 'wrong':
                wrong_guesses += 1
            elif info['status'] == 'repeated':
                repeated_guesses += 1
            
            state = next_state
            results['total_guesses'] += 1
        
        # Record game result
        if info['status'] == 'won':
            results['wins'] += 1
        else:
            results['losses'] += 1
        
        results['total_wrong_guesses'] += wrong_guesses
        results['total_repeated_guesses'] += repeated_guesses
        
        results['game_results'].append({
            'word': word,
            'won': info['status'] == 'won',
            'wrong_guesses': wrong_guesses,
            'repeated_guesses': repeated_guesses
        })
        
        # Progress update
        if (i + 1) % 200 == 0:
            print(f"  Processed {i + 1}/{len(test_words)} words...")
    
    return results

print("Evaluation function defined!")
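
Evaluation is intended to use a greedy policy, so exploration has to be switched off while actions are selected and restored afterwards. One way to sketch this is a context manager. `greedy_mode` and `DummyAgent` are hypothetical names used only for illustration; the sole assumption about the real agent is that it exposes a mutable `epsilon` attribute, as the training log output does.

```python
from contextlib import contextmanager


@contextmanager
def greedy_mode(agent):
    """Temporarily set agent.epsilon to 0 and restore it on exit, even on error."""
    old_epsilon = agent.epsilon
    agent.epsilon = 0.0
    try:
        yield agent
    finally:
        agent.epsilon = old_epsilon


class DummyAgent:
    """Stand-in for the notebook's agent; only the epsilon attribute matters here."""

    def __init__(self, epsilon):
        self.epsilon = epsilon


agent_demo = DummyAgent(epsilon=0.3)
with greedy_mode(agent_demo):
    epsilon_inside = agent_demo.epsilon   # 0.0 while evaluating
epsilon_after = agent_demo.epsilon        # restored to 0.3
print(epsilon_inside, epsilon_after)
```

The `try`/`finally` guarantees that an exception during evaluation cannot leave the agent stuck with exploration disabled for later training.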



In [None]:
# Run evaluation
# Ensure test_words is not empty; try fallbacks if needed
if ("test_words" not in globals()) or (len(test_words) == 0):
    print("Warning: test_words empty. Trying to recover...")
    try:
        # Attempt to use raw test and reapply minimal filtering
        if "test_words_raw" in globals() and len(test_words_raw) > 0:
            test_words = [w.strip().lower() for w in test_words_raw if w and w.strip().isalpha()]
        elif "validation_corpus" in globals() and len(validation_corpus) > 0:
            test_words = validation_corpus[:min(500, len(validation_corpus))]
        elif "corpus" in globals() and len(corpus) > 0:
            test_words = corpus[-min(500, len(corpus)):]
        else:
            test_words = []
    except Exception as _e:
        print("Recovery failed:", _e)
        test_words = []

print(f"Evaluating on {len(test_words)} test words...")
evaluation_results = evaluate_agent(agent, hmm, test_words, max_lives=6)

# Calculate final score
len_test = max(1, len(test_words))
success_rate = evaluation_results['wins'] / len_test
total_wrong = evaluation_results['total_wrong_guesses']
total_repeated = evaluation_results['total_repeated_guesses']

final_score = (success_rate * 2000) - (total_wrong * 5) - (total_repeated * 2)

print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)
print(f"Total Test Words: {len(test_words)}")
print(f"Wins: {evaluation_results['wins']}")
print(f"Losses: {evaluation_results['losses']}")
print(f"Success Rate: {success_rate:.2%}")
print(f"\nTotal Wrong Guesses: {total_wrong}")
print(f"Average Wrong Guesses per Game: {total_wrong / max(1, len(test_words)):.2f}")
print(f"\nTotal Repeated Guesses: {total_repeated}")
print(f"Average Repeated Guesses per Game: {total_repeated / max(1, len(test_words)):.2f}")
print(f"\n{'='*60}")
print(f"FINAL SCORE: {final_score:.2f}")
print(f"{'='*60}")
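
The final score rewards the success rate and penalises the total number of wrong and repeated guesses. A worked sketch of that formula follows; the function name `hangman_score` is hypothetical and the numbers are made up for illustration, not results from this notebook.

```python
def hangman_score(wins, num_games, total_wrong, total_repeated):
    """Score = success_rate * 2000 - 5 * wrong guesses - 2 * repeated guesses."""
    success_rate = wins / max(1, num_games)  # guard against an empty test set
    return success_rate * 2000 - total_wrong * 5 - total_repeated * 2


# Made-up example: 2000 games, half of them won, 4.5 wrong guesses per game on average
example_score = hangman_score(wins=1000, num_games=2000,
                              total_wrong=9000, total_repeated=10)
print(example_score)  # -44020.0
```

The success term is capped at 2000, while the penalties grow with the number of games played, so on a large test set the score is dominated by the wrong-guess total and can be strongly negative.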



## 8. Save Models


In [2]:
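# Imports repeated here so the cell does not fail if the import cell was skipped
import os
import pickle

# Create models directory if it doesn't exist
os.makedirs('../models', exist_ok=True)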

# Save HMM model
with open('../models/hmm_model.pkl', 'wb') as f:
    pickle.dump(hmm, f)
print("HMM model saved to ../models/hmm_model.pkl")

# Save RL agent
agent.save('../models/rl_agent.pkl')

print("\nAll models saved successfully!")


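
To reuse a saved model later, the pickle file can be read back with `pickle.load`. The sketch below shows a minimal round trip. It writes to a temporary directory rather than `../models`, and uses a plain dictionary (`stand_in_model`, hypothetical) in place of the real HMM object.

```python
import os
import pickle
import tempfile

# Stand-in for the trained model; the real object would be the notebook's HMM instance
stand_in_model = {"letter_counts": {"e": 3, "a": 2}}

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "hmm_model.pkl")

    # Save in binary mode, as in the cell above
    with open(model_path, "wb") as f:
        pickle.dump(stand_in_model, f)

    # Load it back in binary mode
    with open(model_path, "rb") as f:
        restored_model = pickle.load(f)

print(restored_model == stand_in_model)  # True
```

When loading an instance of a custom class such as `FirstOrderHMM`, the class must already be defined in the loading session, otherwise unpickling fails. Pickle files should also only be loaded from trusted sources.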

## 9. Additional Analysis (Optional)

### Analyze performance by word length


In [3]:
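# Analyze performance by word length
from collections import defaultdict  # imported here so the cell does not fail if the import cell was skipped

# 'repeated_guesses' must be tracked too: it is averaged per length further down
length_stats = defaultdict(lambda: {'wins': 0, 'total': 0, 'wrong_guesses': [], 'repeated_guesses': []})

for result in evaluation_results['game_results']:
    word = result['word']
    word_len = len(word)
    length_stats[word_len]['total'] += 1
    if result['won']:
        length_stats[word_len]['wins'] += 1
    length_stats[word_len]['wrong_guesses'].append(result['wrong_guesses'])
    length_stats[word_len]['repeated_guesses'].append(result['repeated_guesses'])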

# Enhanced visualization by word length buckets
lengths = sorted(length_stats.keys())
win_rates_by_length = [length_stats[l]['wins'] / length_stats[l]['total'] if length_stats[l]['total'] > 0 else 0 for l in lengths]
avg_wrong_by_length = [np.mean(length_stats[l]['wrong_guesses']) if len(length_stats[l]['wrong_guesses']) > 0 else 0 for l in lengths]
avg_repeated_by_length = [np.mean(length_stats[l]['repeated_guesses']) if len(length_stats[l]['repeated_guesses']) > 0 else 0 for l in lengths]
word_counts_by_length = [length_stats[l]['total'] for l in lengths]

# Create comprehensive visualization
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Win rate by length
axes[0, 0].bar(lengths, win_rates_by_length, color='steelblue', alpha=0.7)
axes[0, 0].set_xlabel('Word Length')
axes[0, 0].set_ylabel('Win Rate')
axes[0, 0].set_title('Win Rate by Word Length Bucket')
axes[0, 0].grid(True, alpha=0.3, axis='y')
axes[0, 0].axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='50% threshold')
axes[0, 0].legend()

# 2. Average wrong guesses by length
axes[0, 1].bar(lengths, avg_wrong_by_length, color='coral', alpha=0.7)
axes[0, 1].set_xlabel('Word Length')
axes[0, 1].set_ylabel('Average Wrong Guesses')
axes[0, 1].set_title('Average Wrong Guesses by Word Length Bucket')
axes[0, 1].grid(True, alpha=0.3, axis='y')

# 3. Average repeated guesses by length
axes[1, 0].bar(lengths, avg_repeated_by_length, color='orange', alpha=0.7)
axes[1, 0].set_xlabel('Word Length')
axes[1, 0].set_ylabel('Average Repeated Guesses')
axes[1, 0].set_title('Average Repeated Guesses by Word Length Bucket')
axes[1, 0].grid(True, alpha=0.3, axis='y')

# 4. Word count distribution in test set
axes[1, 1].bar(lengths, word_counts_by_length, color='green', alpha=0.7)
axes[1, 1].set_xlabel('Word Length')
axes[1, 1].set_ylabel('Number of Words')
axes[1, 1].set_title('Test Set Distribution by Word Length')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Additional analysis: Correlation between word length and performance
if len(lengths) > 1:
    correlation_wrong = np.corrcoef(lengths, avg_wrong_by_length)[0, 1]
    correlation_win = np.corrcoef(lengths, win_rates_by_length)[0, 1]
    
    print(f"\n📊 Correlation Analysis:")
    print(f"  Word Length vs Win Rate: {correlation_win:.3f}")
    print(f"  Word Length vs Wrong Guesses: {correlation_wrong:.3f}")
    
    if correlation_wrong > 0.5:
        print(f"  💡 Longer words tend to have more wrong guesses")
    if correlation_win < -0.3:
        print(f"  💡 Longer words tend to have lower win rates")

print("\nPerformance by Word Length:")
for length in sorted(length_stats.keys()):
    stats = length_stats[length]
    win_rate = stats['wins'] / stats['total']
    avg_wrong = np.mean(stats['wrong_guesses'])
    print(f"Length {length:2d}: Win Rate {win_rate:.2%}, Avg Wrong {avg_wrong:.2f} ({stats['total']} words)")


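
The per-length breakdown relies on a `defaultdict` whose factory creates a fresh statistics record the first time a word length is seen. A minimal sketch of that grouping pattern on three made-up game results (the `games` list is illustrative, not notebook output):

```python
from collections import defaultdict

# Made-up results in the same shape as the entries of game_results
games = [
    {"word": "cat", "won": True, "wrong_guesses": 2},
    {"word": "dog", "won": False, "wrong_guesses": 6},
    {"word": "tree", "won": True, "wrong_guesses": 1},
]

# The lambda runs once per new key, so every length gets its own counters
stats = defaultdict(lambda: {"wins": 0, "total": 0})
for game in games:
    bucket = stats[len(game["word"])]
    bucket["total"] += 1
    bucket["wins"] += int(game["won"])

win_rate_by_length = {
    length: s["wins"] / s["total"] for length, s in sorted(stats.items())
}
print(win_rate_by_length)  # {3: 0.5, 4: 1.0}
```

Every key read later must be created by the factory: reading a key the factory never creates raises a `KeyError`, even on a `defaultdict`.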