##### Imports

In [None]:
# Notebook bootstrap cell: imports plus a few process-wide settings.
# NOTE: several imports below are repeated (sys, warnings, Path, random, pickle,
# read_words); re-importing is harmless. Statement order matters in a few places
# and is called out where it does.
import sys
from pathlib import Path
import warnings

import warnings
import pandas as pd
# Install a blanket "ignore" filter: warnings raised after this line are not shown.
warnings.filterwarnings('ignore')

# Let pandas display up to 1000 columns/rows before truncating its output.
pd.set_option('display.max_columns', 1000)
pd.set_option('display.max_rows', 1000)

import sys
# Custom library paths. These must be appended BEFORE the `scr.*` imports below;
# presumably this is what makes the project-local `scr` package resolvable
# when the notebook is started from a different working directory.
sys.path.extend(['../', './scr'])

from scr.utils import set_seed
from scr.utils import read_words
from pathlib import Path
import random
from collections import Counter, defaultdict
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset

from scr.utils import read_words, save_words_to_file

import pickle
from pathlib import Path
# Star imports: names used later without an explicit import (e.g. HangmanDataset,
# sample_words, calculate_word_frequencies) presumably come from these modules --
# TODO confirm and replace with explicit imports.
from scr.dataset import *
from scr.utils import *
# # For inference
from scr.feature_engineering import *

import gc

# Call the project seeding helper before any randomised step (the word list is
# shuffled further down in this cell). Which RNGs it covers is defined in
# scr.utils -- confirm there.
set_seed(42)

import torch
import torch.nn as nn
from pathlib import Path
import random

from scr.utils import print_scenarios
# 'medium' lets PyTorch use reduced-precision internals for float32 matmuls,
# trading some numerical precision for speed.
torch.set_float32_matmul_precision('medium')
from pathlib import Path

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the full training vocabulary and shuffle it in place.
# (Earlier experiments passed limit=10000 to read_words, or read 'data/250k.txt',
# to work on a smaller subset.)
word_list = read_words('data/words_250000_train.txt')
random.shuffle(word_list)

# Corpus-level statistics reused by the game/evaluation helpers further down.
word_frequencies = calculate_word_frequencies(word_list)
char_frequency = calculate_char_frequencies(word_list)

# Length of the longest word in the vocabulary.
max_word_length = max(map(len, word_list))

##### Data Dir

In [None]:
# Run parameters for this notebook.
# NUM_STRATIFIED_SAMPLES selects the dataset sub-directory and the model
# checkpoint file name used by the cells below.
NUM_STRATIFIED_SAMPLES = 100 # This will be overwritten by Papermill

# Number of held-out words passed to sample_words() for evaluation.
NUM_WORD_SAMPLE = 1_000 # words for testing

# NOTE(review): FAST_DEV_RUN and MAX_EPOCH are not referenced in the visible
# cells; they look like training-loop settings -- confirm they are still needed.
FAST_DEV_RUN = False

MAX_EPOCH = 15

In [None]:
from pathlib import Path

# --- Dataset locations: one sub-directory per stratified-sample count ---
base_dataset_dir = Path("/media/sayem/510B93E12554BBD1/dataset/")
stratified_samples_dir = base_dataset_dir.joinpath(str(NUM_STRATIFIED_SAMPLES))
parquet_train_path = stratified_samples_dir.joinpath('train_parquets')
parquet_valid_path = stratified_samples_dir.joinpath('valid_parquets')

# --- Model checkpoints, training outputs and the logger directory ---
models_dir = Path("/home/sayem/Desktop/Hangman/models")
output_dir = Path("/home/sayem/Desktop/Hangman/training_outputs")
logger_dir = output_dir.joinpath("lightning_logs")

# Create every directory (including missing parents); existing ones are left alone.
for _required_dir in (
    parquet_train_path,
    parquet_valid_path,
    models_dir,
    output_dir,
    logger_dir,
):
    _required_dir.mkdir(parents=True, exist_ok=True)

# File holding the words reserved for testing, stored alongside the parquet folders.
testing_words_file_path = stratified_samples_dir / "testing_words.txt"

# Only the read can raise FileNotFoundError, so keep the try body to that call.
try:
    testing_word_list = read_words(testing_words_file_path)
except FileNotFoundError:
    # Previously the error was only printed, which left `sampled_test_words`
    # undefined and made the next statement (and the evaluation cell) fail with
    # an unrelated NameError. Report the path and stop with the real cause.
    print(f"File not found: {testing_words_file_path}")
    raise

print(f"Length of the testing word list: {len(testing_word_list)}")

# Draw the evaluation subset used by the testing cells below.
sampled_test_words = sample_words(testing_word_list, NUM_WORD_SAMPLE)
print(f"Sampled {len(sampled_test_words)} unique words for testing.")

print(len(sampled_test_words))

##### Dataset Loading

In [None]:
# Build the train/validation datasets from their parquet directories.
train_dataset, valid_dataset = (
    HangmanDataset(split_dir)
    for split_dir in (parquet_train_path, parquet_valid_path)
)

# Report the size of each split (train first, then validation).
for split in (train_dataset, valid_dataset):
    print(len(split))

# Sanity check: the training split must be strictly larger than the validation split.
assert len(valid_dataset) < len(train_dataset)

##### Data Module

##### Model Initialization

##### Testing

In [None]:
# The checkpoint file name is prefixed with the stratified-sample count.
trained_model_file_path = models_dir.joinpath(
    f"{NUM_STRATIFIED_SAMPLES}_trained_model.pth"
)
trained_model = torch.load(trained_model_file_path)

# Show the class of the object that was deserialised from the checkpoint.
print(type(trained_model))

In [None]:
# Load the entire LSTM model object from the checkpoint for this sample count.
trained_model_file_path = models_dir / f"{NUM_STRATIFIED_SAMPLES}_trained_model.pth"
trained_model = torch.load(trained_model_file_path)

# The model is only used for inference from here on, so switch it to evaluation
# mode: a freshly loaded module may still be in training mode, in which case
# dropout / batch-norm layers would make the game results noisy.
# NOTE: call trained_model.train() again before any further training.
trained_model.eval()

from scr.game import *

# Single smoke-test word.
word = 'may'

# Kept as the cell's last expression so the notebook displays the game result.
play_game_with_a_word(trained_model, word, char_frequency, max_word_length)

In [None]:
# Echo the active sample count; as the cell's last expression the notebook displays it.
NUM_STRATIFIED_SAMPLES

In [None]:
# Evaluate the loaded model on the sampled test words.
result = play_games_and_calculate_stats(
    trained_model,
    sampled_test_words,
    char_frequency,
    max_word_length,
)

# Headline numbers across all sampled words.
print(f"Overall Win Rate: {result['overall_win_rate']}%, Overall Average Attempts: {result['overall_avg_attempts']}")

# Per-word-length breakdown.
for word_length, length_stats in result["length_wise_stats"].items():
    print(f"Length {word_length}: Win Rate: {length_stats['win_rate']}%, Average Attempts: {length_stats['average_attempts_used']}")

In [None]:
# Keep the per-word-length statistics from the evaluation run; they are handed
# to HangmanDataModule (performance_metrics=...) in the next cell.
performance_metrics = result['length_wise_stats']

In [None]:
from scr.data_module import *

# Batch size the data module starts out with.
initial_batch_size = 64

# Build the data module from both datasets, the collate function and the
# per-length metrics gathered during evaluation (win-rate threshold: 50).
data_module = HangmanDataModule(
    train_dataset,
    valid_dataset,
    initial_batch_size,
    custom_collate_fn,
    performance_metrics=performance_metrics,
    threshold_win_rate=50,
)

# Inspect the first three training batches, then stop.
for i, batch in enumerate(data_module.train_dataloader()):
    # Pull the fields of interest out of the batch dictionary.
    guessed_states, guessed_letters, max_seq_len, original_seq_lengths = (
        batch[field]
        for field in (
            'guessed_states',
            'guessed_letters',
            'max_seq_len',
            'original_seq_lengths',
        )
    )

    # Dump the batch contents for manual inspection.
    print(f"Batch {i}:")
    print(f"Max Sequence Length: {max_seq_len}")
    print(f"Original Sequence Lengths: {original_seq_lengths}")
    print(f"Guessed States: {guessed_states}")
    print(f"Guessed Letters: {guessed_letters}")

    # Sanity check: report the word length implied by each sequence's first
    # state (assumed to be the initial state). A membership check against a set
    # of target word lengths is currently disabled.
    for states in guessed_states:
        word_len = len(states[0])
        print(f"Word Length from State: {word_len}")

    # Batches 0, 1 and 2 have been shown; leave before fetching a fourth.
    if i == 2:
        break