##### Imports

In [1]:
import sys
from pathlib import Path
import warnings

import warnings
import pandas as pd
warnings.filterwarnings('ignore')

pd.set_option('display.max_columns', 1000)
pd.set_option('display.max_rows', 1000)

import sys
# Custom library paths
sys.path.extend(['../', './scr'])

from scr.utils import set_seed
from scr.utils import read_words
from pathlib import Path
import random
from collections import Counter, defaultdict
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset
from scr.feature_engineering import \
    calculate_char_frequencies, calculate_word_frequencies
from scr.utils import read_words, save_words_to_file

import pickle
from pathlib import Path
from scr.dataset import *
from scr.game import *
from scr.plot_utils import *

import gc

set_seed(42)

import torch
import torch.nn as nn
from pathlib import Path
import random

from scr.utils import print_scenarios
torch.set_float32_matmul_precision('medium')

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Read and Shuffle Word List
word_list = read_words('data/words_250000_train.txt') # , limit=10000)
# word_list = read_words('data/250k.txt', limit=10000)
random.shuffle(word_list)

# base_dataset_dir = Path('dataset/pkl')

##### Reading Data

In [2]:
from pathlib import Path

NUM_STRATIFIED_SAMPLES = 250_000
# # # Define the base directory

base_dataset_dir = Path(f"/media/sayem/510B93E12554BBD1/dataset/{NUM_STRATIFIED_SAMPLES}")


parquet_file_path = base_dataset_dir / "HangmanData.parquet"

# Paths to the words files
train_words_file_path = base_dataset_dir / 'train_words.txt'
test_words_file_path = base_dataset_dir / 'test_words.txt'

# Read the words from the files
try:
    train_words = read_words(train_words_file_path)
    print(f"Loaded {len(train_words)} train words from {train_words_file_path}")
except FileNotFoundError:
    print(f"File not found: {train_words_file_path}")

Loaded 181840 train words from /media/sayem/510B93E12554BBD1/dataset/250000/train_words.txt


In [3]:
base_dataset_dir

PosixPath('/media/sayem/510B93E12554BBD1/dataset/250000')

In [4]:
# # For inference
from scr.feature_engineering import *

word_frequencies = calculate_word_frequencies(word_list)
char_frequency = calculate_char_frequencies(word_list)
max_word_length = max(len(word) for word in word_list)

##### Model Building

In [5]:
from scr.simple_model import SimpleLSTM
from scr.base_model import BaseModel
from pathlib import Path
import torch
from scr.feature_engineering import *

# max_word_length = 29 # TODO will remove later
# Instantiate and test the model
config = {
    'embedding_dim': 200,
    'hidden_dim': 256,
    'num_layers': 2,
    'vocab_size': 27,
    'max_word_length': max_word_length,
    'input_feature_size': 5,
    'use_embedding': True,
    'miss_linear_dim': 150,
    'lr': 0.0001,
    'optimizer_type': 'Adam'
}

model = SimpleLSTM(config)
# optimizer = model.optimizer

# NOTE: `model` was instantiated from `config` just above and has not been trained
# in this notebook yet, so the checkpoint saved here holds freshly initialised state.
model.save_model(file_path='models/model_.pth') ## TODO: drop the trailing "_" from the file name later


# Path of the checkpoint saved above (same 'models/model_.pth' literal).
model_file_path = 'models/model_.pth' ## TODO: drop the trailing "_" from the file name later

# Specify the device to load the model onto
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Load the model
# loaded_model = BaseModel.load_model(SimpleLSTM, model_file_path)
# # Now `loaded_model` is an instance of `SimpleLSTM` with the state and config loaded

# model = loaded_model

In [6]:
type(model.optimizer).__name__

'Adam'

##### Dataset Loading and train-test split

In [7]:
from torch.utils.data import DataLoader

from scr.dataset import HangmanDataset # , custom_collate_fn

from scr.feature_engineering import process_batch_of_games

from sklearn.model_selection import train_test_split

# Load the dataset
hangman_dataset = HangmanDataset(parquet_file_path)  # Replace with your Parquet file path

# Split the dataset
train_dataset, valid_dataset = train_test_split(hangman_dataset, \
    test_size=0.20, random_state=42)

# Now, you can use train_dataset for training and valid_dataset for validation

In [8]:
hangman_dataset[10 ** 2]

{'game_id': 9533801,
 'word': 'preacknowledgement',
 'initial_state': ['___ack____________'],
 'final_state': 'p__ack____________',
 'guessed_states': ['___ack____________',
  '___ack____________',
  '___ack____________',
  '___ack____________',
  '___ack____________',
  'p__ack____________',
  'p__ack____________'],
 'guessed_letters': ['i', 'f', 'u', 'y', 'p', 'b', 'v'],
 'game_state': 'quarterRevealed',
 'difficulty': 'hard',
 'outcome': 'win',
 'word_length': 18,
 'won': False}

In [14]:
data_loader = DataLoader(hangman_dataset, batch_size=32, \
    collate_fn=custom_collate_fn, shuffle=True)

In [18]:
for batch in data_loader:
    states = batch['guessed_states']
    guesses = batch['guessed_letters']
    max_seq_length = batch['max_seq_len']
    original_seq_lengths = batch['original_seq_lengths']

    # Preprocess the batch data
    batch_features, batch_missed_chars = process_batch_of_games(
        states, guesses, char_frequency, max_word_length, max_seq_length)

    # Move the tensors to the specified device
    batch_features = batch_features.to(device)
    batch_missed_chars = batch_missed_chars.to(device)

    # Print the shapes
    print("Shape of batch_features:", batch_features.shape)
    print("Shape of batch_missed_chars:", batch_missed_chars.shape)


    

    break

Shape of batch_features: torch.Size([32, 17, 145])
Shape of batch_missed_chars: torch.Size([32, 17, 27])


In [9]:
from scr.dataset import *

val_samples = create_validation_samples(valid_dataset)
print(val_samples[0])  # This will print the list of validation samples

(['e__e__e', 'k'], 'entente')


In [11]:
# Smoke-test feature extraction on the first validation sample only
# (the loop exits after one iteration).
for data in val_samples:
    # Each sample looks like ([state, guess], word), e.g. (['e__e__e', 'k'], 'entente').
    state, guess = data[0][0], data[0][1]
    word = data[1]
    print(state, guess)

    # Wrap the single state/guess in lists and use a sequence length of 1,
    # i.e. a batch holding exactly one game with one step.
    fets_, missed_char_ = process_batch_of_games(
        [state], [guess], char_frequency, max_word_length, 1)

    # The former `label = pad_and_reshape_labels` line was removed: it only bound
    # the function object (it was never called) and nothing read `label`.

    print(fets_.shape)

    print(missed_char_)

    print(guess)

    break

e__e__e k
torch.Size([1, 1, 145])
tensor([[[1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])
k


In [None]:
STOP

NameError: name 'STOP' is not defined

In [None]:
# from collections import defaultdict
# import random

# # Define the maximum number of games you want to play per epoch
# max_games_per_epoch = 1000  # Example value, adjust as needed

# # Assuming train_data_loader is your DataLoader for training data
# unique_words_set = set()
# for _, _, _, batch_full_words in val_loader:
#     unique_words_set.update(batch_full_words)

# # Group words by their lengths
# words_by_length = defaultdict(list)
# for word in unique_words_set:
#     words_by_length[len(word)].append(word)

# # Calculate the number of words to select from each length group
# num_words_per_length = max_games_per_epoch // len(words_by_length)

# selected_words = []
# for length, words in words_by_length.items():
#     random.shuffle(words)
#     selected_words.extend(words[:num_words_per_length])

# # # # Now, use 'selected_words' in place of 'unique_words' for game simulation
# # selected_words

In [None]:
# # # # Initialize DataLoader outside the loop
# batch_size = 64
# train_loader = DataLoader(train_data, batch_size=batch_size, \
#         collate_fn=processed_dataset.custom_collate_fn)
# # epochs = 10
# # # model.train()

# for epoch in range(epochs):
#     train_loader = DataLoader(train_data, batch_size=batch_size, \
#         collate_fn=processed_dataset.custom_collate_fn)

#     train_loss, train_miss_penalty = train_on_data_loader(model, train_loader, device)
#     print(f"Epoch {epoch}: Training Loss: {train_loss}, Miss Penalty: {train_miss_penalty}")

##### Untrained Model Performance

In [None]:
from scr.model_training import validate_hangman
# Now call the validate_hangman function with this set
validation_results = validate_hangman(model, val_loader, \
    char_frequency, max_word_length, device, selected_words)

In [None]:
from scr.plot_utils import *

print(f'~Untrained model performence~')
# # Print results for character-level validation
# print("Character Level Validation:")
print(f"Average Loss: {validation_results['avg_loss']}")
print(f"Miss Penalty: {validation_results['avg_miss_penalty']}")

game_simulation_results = validation_results["game_simulation"]

save_path = Path(f"plots/{NUM_STRATIFIED_SAMPLES}_untrained_model_performence_word_stats_plot.png")
plot_word_stats(game_simulation_results["length_stats"], save_path)

# # Print results for game simulation
# print("\nGame Simulation:")
print(f"Win Rate: {game_simulation_results['win_rate']}")
print(f"Average Attempts: {game_simulation_results['average_attempts']}")
print(f"Total Games: {game_simulation_results['total_games']}")
print(f"Total Wins: {game_simulation_results['total_wins']}")
print(f"Total Losses: {game_simulation_results['total_losses']}")
# print("Word Stats:", game_simulation_results["game_stats"])
# print("Word Length Stats:", game_simulation_results["length_stats"])

In [None]:
init_performance_dict = game_simulation_results["length_stats"]

# init_performance_dict = {
#     length: {"wins": 1, "losses": 1, "total_attempts": 2, "games": 2}
#     for length in range(1, max_word_length + 1)
# }

init_performance_dict

In [None]:
# batch_size = 64  # Adjust as needed
# custom_sampler = PerformanceBasedSampler(train_data, init_performance_dict)

# # # Check sampler output
# # print("Sampler output check:")
# # for i, idx in enumerate(custom_sampler):
# #     print(f"Index {i}: {idx}")
# #     if i >= 10:  # Adjust for more checks
# #         break

# # DataLoader with the custom sampler
# train_loader = DataLoader(train_data, batch_size=128, \
#     sampler=custom_sampler, collate_fn=processed_dataset.custom_collate_fn)

# # Check DataLoader output
# print("\nDataLoader output check:")
# for i, batch in enumerate(train_loader):
#     game_states, seq_lengths, missed_chars, indices, info = batch
#     # print(f"\nBatch {i} details:")
#     # print("Game States Tensor:", game_states.shape)
#     # print("Sequence Lengths:", seq_lengths.shape)
#     # print("Missed Characters Tensor:", missed_chars.shape)
#     # print("Indices Tensor:", indices)
#     # print("Information List:", info)
#     # if i >= 1:  # Adjust for more checks
#     #     break


##### Training

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau, CyclicLR  # Import CyclicLR or other schedulers
from torch.utils.data import DataLoader

def train_model(model, train_data, init_performance_dict, val_loader,
                num_epochs, batch_size, selected_words, device, scheduler_type='CyclicLR'):
    """Train ``model`` with a performance-weighted sampler and per-epoch validation.

    Every epoch rebuilds a ``PerformanceBasedSampler`` from the most recent
    per-length game statistics, trains through ``train_on_data_loader``, then
    validates with ``validate_hangman`` and feeds the new
    ``["game_simulation"]["length_stats"]`` into the next epoch's sampler.

    Relies on notebook-level names: ``PerformanceBasedSampler``,
    ``processed_dataset`` (for its ``custom_collate_fn``),
    ``train_on_data_loader``, ``validate_hangman``, ``char_frequency`` and
    ``max_word_length``.

    NOTE(review): a later cell defines another ``train_model`` with a different
    signature; once that cell runs, this definition is shadowed.

    Args:
        model: Model exposing ``.optimizer``, ``.train()`` and ``.eval()``.
        train_data: Dataset handed to the sampler and the training ``DataLoader``.
        init_performance_dict: Initial per-word-length statistics for the sampler.
        val_loader: Validation loader forwarded to ``validate_hangman``.
        num_epochs: Maximum number of epochs.
        batch_size: Training batch size.
        selected_words: Words forwarded to ``validate_hangman``.
        device: Torch device forwarded to the train/validate helpers.
        scheduler_type: ``'CyclicLR'`` or ``'ReduceLROnPlateau'``. Any other value
            trains without a learning-rate scheduler.

    Returns:
        None. Progress is reported with ``print``.
    """
    best_val_loss = float('inf')
    n_epochs_stop = 10  # patience: epochs without validation improvement before stopping
    epochs_no_improve = 0

    # Define the scheduler
    if scheduler_type == 'CyclicLR':
        scheduler = CyclicLR(model.optimizer, base_lr=0.00001, max_lr=0.001, step_size_up=5,
                             mode='triangular', cycle_momentum=False)
    elif scheduler_type == 'ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(model.optimizer, mode='min', factor=0.1, patience=5)
    else:
        # Unrecognised name: keep the optimizer's own learning rate unchanged.
        scheduler = None

    for epoch in range(num_epochs):
        # Recalculate sampler for each epoch from the latest length statistics
        custom_sampler = PerformanceBasedSampler(train_data, init_performance_dict)
        train_loader = DataLoader(train_data, batch_size=batch_size, sampler=custom_sampler,
                                  collate_fn=processed_dataset.custom_collate_fn)

        model.train()
        train_loss, train_miss_penalty = train_on_data_loader(model, train_loader, device)

        print(f"Epoch {epoch}: Training Loss: {train_loss}, Miss Penalty: {train_miss_penalty}")

        # Validation phase
        model.eval()
        with torch.no_grad():
            validation_results = validate_hangman(model, val_loader, char_frequency,
                                                  max_word_length, device, selected_words)
            val_loss = validation_results['avg_loss']
            val_miss_penalty = validation_results["avg_miss_penalty"]
            init_performance_dict = validation_results["game_simulation"]["length_stats"]

        # Scheduler stepping (once per epoch). ReduceLROnPlateau needs the monitored metric.
        if scheduler_type == 'ReduceLROnPlateau':
            scheduler.step(val_loss)
        elif scheduler_type == 'CyclicLR':
            scheduler.step()

        # Report validation metrics BEFORE the early-stopping check so the epoch that
        # triggers the stop is still visible in the log.
        print(f"Epoch {epoch}: Validation Loss: {val_loss}, Validation Miss Penalty: {val_miss_penalty}")
        print(f"Win Rate: {validation_results['game_simulation']['win_rate']}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= n_epochs_stop:
                print("Early stopping triggered")
                break

num_epochs = 10
batch_size = 1024  # Define your batch size
model = SimpleLSTM(config)
# optimizer = model.optimizer
# Training the model
train_model(model, train_data, init_performance_dict, val_loader, num_epochs, batch_size, 
            selected_words, device, scheduler_type='ReduceLROnPlateau')

In [None]:
# from torch.optim.lr_scheduler import ReduceLROnPlateau, CyclicLR  # Import CyclicLR or other schedulers
# from torch.utils.data import DataLoader

# def train_model(model, train_data, init_performance_dict, val_loader, \
#     num_epochs, batch_size, selected_words, scheduler=None, device=device):
#     best_val_loss = float('inf')
#     n_epochs_stop = 10
#     epochs_no_improve = 0

#     for epoch in range(num_epochs):
#         custom_sampler = PerformanceBasedSampler(train_data, init_performance_dict)
#         train_loader = DataLoader(train_data, batch_size=batch_size, sampler=custom_sampler, \
#             collate_fn=processed_dataset.custom_collate_fn)

#         # train_loader = DataLoader(train_data, batch_size=batch_size, \
#         #     collate_fn=processed_dataset.custom_collate_fn)

#         model.train()
#         train_loss, train_miss_penalty = train_on_data_loader(model, \
#             train_loader, device)

#         print(f"Epoch {epoch}: Training Loss: {train_loss}, Miss Penalty: {train_miss_penalty}")

#         model.eval()
#         with torch.no_grad():
#             validation_results = validate_hangman(model, val_loader, char_frequency, \
#                 max_word_length, device, selected_words)
#             val_loss = validation_results['avg_loss']
#             val_miss_penalty = validation_results["avg_miss_penalty"]
#             init_performance_dict = validation_results["game_simulation"]["length_stats"]

#         if scheduler:
#             scheduler.step(val_loss)  # Assuming the scheduler is ReduceLROnPlateau or similar
#             # scheduler.step(val_loss)

#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             epochs_no_improve = 0
#         else:
#             epochs_no_improve += 1
#             if epochs_no_improve == n_epochs_stop:
#                 print("Early stopping triggered")
#                 break

#         print(f"Epoch {epoch}: val loss: {val_loss}, val miss penalty: {val_miss_penalty}")
#         print(f"Win Rate: {validation_results['game_simulation']['win_rate']}")
#         print()


# num_epochs = 10
# batch_size = 64 # hyperparams

# # Example of calling train_model with a CyclicLR scheduler, disabling cycle_momentum
# scheduler = CyclicLR(model.optimizer, base_lr=0.0001, max_lr=0.01, step_size_up=5, \
#     mode='triangular', cycle_momentum=False)
    
# train_model(model, train_data, init_performance_dict, \
#     val_loader, num_epochs, batch_size, selected_words, scheduler, device)

In [None]:
# from sklearn.model_selection import KFold
# import torch
# from torch.utils.data import DataLoader, Subset

# def k_fold_cross_validate(model, dataset, k, num_epochs, optimizer, scheduler_class, device):
#     kfold = KFold(n_splits=k, shuffle=True, random_state=42)
    
#     for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
#         print(f"Fold {fold}")
#         train_subset = Subset(dataset, train_idx)
#         val_subset = Subset(dataset, val_idx)

#         # Initialize a new model for each fold
#         model = SimpleLSTM(config)
#         model.to(device)
#         optimizer = model.optimizer
#         scheduler = scheduler_class(optimizer)

#         train_subset = Subset(dataset, train_idx)
#         val_subset = Subset(dataset, val_idx)
#         train_model(model, train_subset, val_subset, num_epochs, optimizer, scheduler, device)


#         train_model(model, train_subset, val_subset, num_epochs, optimizer, scheduler, device)

# # Usage example
# num_epochs = 10
# batch_size = 64
# scheduler_class = torch.optim.lr_scheduler.ReduceLROnPlateau  # Example scheduler class
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# k_fold_cross_validate(model, processed_dataset, \
#     5, num_epochs, optimizer, scheduler_class, device)

##### Trained Model Performance

In [None]:
# NOTE(review): `processed_dataset` and `val_data` are not created by any live cell
# in this notebook; confirm where they should come from before running this cell.
val_loader = processed_dataset.create_val_loader(val_data)


# Call the validation function
# NOTE(review): the other call sites also pass `selected_words`; confirm whether it
# is optional for `validate_hangman`.
validation_results = validate_hangman(model, val_loader, \
    char_frequency, max_word_length, device)

# Access the results
game_simulation_results = validation_results["game_simulation"]

from scr.plot_utils import *

save_path = Path(f"plots/{NUM_STRATIFIED_SAMPLES}_trained_model_performence_word_stats_plot.png")
plot_word_stats(game_simulation_results["length_stats"], save_path)

# Character-level metrics are read from the top level of the result, matching how
# the untrained-model report and the training loop read the same return value
# (there is no "character_level" sub-dict at those call sites).
print(f"Average Loss: {validation_results['avg_loss']}")
print(f"Miss Penalty: {validation_results['avg_miss_penalty']}")

# Game simulation metrics
print(f"Win Rate: {game_simulation_results['win_rate']}")
print(f"Average Attempts: {game_simulation_results['average_attempts']}")
print(f"Total Games: {game_simulation_results['total_games']}")
print(f"Total Wins: {game_simulation_results['total_wins']}")
print(f"Total Losses: {game_simulation_results['total_losses']}")
# print("Word Stats:", game_simulation_results["game_stats"])
# print("Word Length Stats:", game_simulation_results["length_stats"])

In [None]:
print("Word Stats:", game_simulation_results["game_stats"])

In [None]:
STOP

In [None]:
batch_features_tensor.shape

In [None]:
len(val_loader)

In [None]:
batch_features_tensor.shape

In [None]:
STOP

In [None]:
STOP

In [None]:
def train_epoch(model, data_loader, optimizer, device, vocab_size=27):
    """Run one optimisation pass over ``data_loader``.

    Each batch is unpacked as ``(game_states, lengths, missed_chars, labels)``.
    Batches whose first element is ``None`` are skipped. Game states and missed
    characters are moved to ``device``; ``lengths`` is passed through unchanged.

    Args:
        model: Callable model exposing ``train()`` and ``calculate_loss(...)``.
        data_loader: Iterable of batches as described above.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` wrap each batch.
        device: Target device for the input tensors.
        vocab_size: Last argument forwarded to ``model.calculate_loss``
            (defaults to the previously hard-coded 27).

    Returns:
        ``(avg_loss, avg_miss_penalty)`` averaged over the batches that were
        actually processed. ``(0.0, 0.0)`` when no batch was processed, so an
        empty loader no longer raises ``ZeroDivisionError`` and skipped batches
        no longer dilute the average.
    """
    model.train()  # Set the model to training mode
    total_loss = 0.0
    total_miss_penalty = 0.0
    batches_processed = 0

    for batch in data_loader:
        if batch[0] is None:
            continue  # Skip empty batches

        game_states_batch, lengths_batch, missed_chars_batch, labels_batch = batch
        game_states_batch = game_states_batch.to(device)
        missed_chars_batch = missed_chars_batch.to(device)

        optimizer.zero_grad()

        outputs = model(game_states_batch, lengths_batch, missed_chars_batch)
        # Labels are padded/reshaped to the model's output shape before the loss.
        reshaped_labels = pad_and_reshape_labels(labels_batch, outputs.shape).to(device)

        loss, miss_penalty = model.calculate_loss(
            outputs, reshaped_labels, lengths_batch, missed_chars_batch, vocab_size)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_miss_penalty += miss_penalty.item()  # Accumulate miss penalty
        batches_processed += 1

    if batches_processed == 0:
        return 0.0, 0.0

    return total_loss / batches_processed, total_miss_penalty / batches_processed



def validate_epoch(model, data_loader, device, vocab_size=27):
    """Evaluate ``model`` over ``data_loader`` without tracking gradients.

    Each batch is unpacked as ``(game_states, lengths, missed_chars, labels)``.
    Batches whose first element is ``None`` are skipped. Game states and missed
    characters are moved to ``device``; ``lengths`` is passed through unchanged.

    Args:
        model: Callable model exposing ``eval()`` and ``calculate_loss(...)``.
        data_loader: Iterable of batches as described above.
        device: Target device for the input tensors.
        vocab_size: Last argument forwarded to ``model.calculate_loss``
            (defaults to the previously hard-coded 27).

    Returns:
        ``(avg_loss, avg_miss_penalty)`` averaged over the batches that were
        actually processed. ``(0.0, 0.0)`` when no batch was processed, so an
        empty loader no longer raises ``ZeroDivisionError`` and skipped batches
        no longer dilute the average.
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = 0.0
    total_miss_penalty = 0.0
    batches_processed = 0

    with torch.no_grad():
        for batch in data_loader:
            if batch[0] is None:
                continue  # Skip empty batches

            game_states_batch, lengths_batch, missed_chars_batch, labels_batch = batch
            game_states_batch = game_states_batch.to(device)
            missed_chars_batch = missed_chars_batch.to(device)

            outputs = model(game_states_batch, lengths_batch, missed_chars_batch)
            # Labels are padded/reshaped to the model's output shape before the loss.
            reshaped_labels = pad_and_reshape_labels(labels_batch, outputs.shape).to(device)

            loss, miss_penalty = model.calculate_loss(
                outputs, reshaped_labels, lengths_batch, missed_chars_batch, vocab_size)

            total_loss += loss.item()
            total_miss_penalty += miss_penalty.item()  # Accumulate miss penalty
            batches_processed += 1

    if batches_processed == 0:
        return 0.0, 0.0

    return total_loss / batches_processed, total_miss_penalty / batches_processed

# for epoch in range(num_epochs):
#     train_loss = train_epoch(model, train_loader, optimizer, device)
#     val_loss = validate_epoch(model, val_loader, device)
    
#     print(f"Epoch {epoch}: Training Loss: {train_loss}, Validation Loss: {val_loss}")
#     # You can add code to save model checkpoints if needed

In [None]:
from sklearn.model_selection import KFold
import torch
from torch.utils.data import DataLoader, Subset

def train_model(model, train_data, val_data, num_epochs, optimizer, scheduler, device):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Builds one ``DataLoader`` each for ``train_data`` and ``val_data`` using the
    notebook-level ``batch_size`` and ``custom_collate_fn``, then alternates
    ``train_epoch`` / ``validate_epoch`` and steps the scheduler on the
    validation loss.

    NOTE(review): this definition replaces the earlier ``train_model`` defined in
    the "Training" section, which takes different parameters.

    Args:
        model: Model passed to ``train_epoch`` / ``validate_epoch``.
        train_data: Dataset for the training loader.
        val_data: Dataset for the validation loader.
        num_epochs: Number of epochs to run.
        optimizer: Optimizer passed to ``train_epoch``.
        scheduler: Scheduler whose ``step`` accepts the validation loss
            (e.g. ``ReduceLROnPlateau``).
        device: Torch device forwarded to the epoch helpers.

    Returns:
        None. Per-epoch metrics are printed.
    """
    train_loader = DataLoader(train_data, batch_size=batch_size,
                              collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_data, batch_size=batch_size,
                            collate_fn=custom_collate_fn)

    for epoch in range(num_epochs):
        train_loss, train_miss_penalty = train_epoch(model, train_loader,
                                                     optimizer, device)
        val_loss, val_miss_penalty = validate_epoch(model, val_loader, device)

        scheduler.step(val_loss)  # Adjust LR based on validation loss

        # Adjacent f-strings instead of a backslash continuation inside one literal:
        # the continuation embedded the next line's indentation in the printed text.
        print(f"Epoch {epoch}: Training Loss: {train_loss}, "
              f"Miss Penalty: {train_miss_penalty}, "
              f"Validation Loss: {val_loss}, "
              f"Validation Miss Penalty: {val_miss_penalty}")
        # Save model checkpoints if needed

def k_fold_cross_validate(model, dataset, k, num_epochs, optimizer, scheduler_class, device):
    """Train a fresh ``SimpleLSTM`` on each of ``k`` shuffled folds of ``dataset``.

    Args:
        model: Ignored. A new ``SimpleLSTM(config)`` is created for every fold;
            the parameter is kept so existing calls keep working.
        dataset: Indexable dataset that is split by ``KFold`` and wrapped in ``Subset``.
        k: Number of folds.
        num_epochs: Epochs per fold, forwarded to ``train_model``.
        optimizer: Ignored. Each fold uses its own model's ``.optimizer``.
        scheduler_class: Callable building a scheduler from an optimizer
            (called as ``scheduler_class(optimizer)``).
        device: Torch device the per-fold model is moved to.

    Returns:
        None. Training progress is printed by ``train_model``.
    """
    kfold = KFold(n_splits=k, shuffle=True, random_state=42)

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f"Fold {fold}")
        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(dataset, val_idx)

        # Initialize a new model (and its optimizer/scheduler) for each fold so no
        # state carries over between folds.
        fold_model = SimpleLSTM(config)
        fold_model.to(device)
        fold_optimizer = fold_model.optimizer
        scheduler = scheduler_class(fold_optimizer)

        # Train exactly once per fold. The earlier version repeated this call, which
        # doubled the work and kept training the same model for a second run.
        train_model(fold_model, train_subset, val_subset, num_epochs,
                    fold_optimizer, scheduler, device)

# Usage example
num_epochs = 10
batch_size = 64  # read as a global by the train_model defined in this cell
scheduler_class = torch.optim.lr_scheduler.ReduceLROnPlateau  # Example scheduler class
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `optimizer` is never assigned by a live cell (only in commented-out lines), so pass
# the model's own optimizer instead of an undefined name.
# NOTE(review): `processed_dataset` is not created by any live cell in this notebook;
# confirm which dataset object should be cross-validated.
k_fold_cross_validate(model, processed_dataset,
                      5, num_epochs, model.optimizer, scheduler_class, device)

In [None]:
STOP

In [None]:
from scr.dataset import ProcessedHangmanDataset
from scr.custom_sampler import PerformanceBasedSampler


print(f"Total number of data points in the dataset: {len(processed_dataset)}")

# Initialize the sampler
sampler = PerformanceBasedSampler(
    processed_dataset, 
    performance_metrics,
    max_word_length=max_word_length
)

print(f"Sampler created with {len(sampler)} indices.")

In [None]:
sample_indices = list(sampler)[:10]  # Get the first 10 sampled indices
print("Sampled indices:", sample_indices)

for idx in sample_indices:
    data_point = processed_dataset[idx]
    print(f"Data at index {idx}: {data_point}")
    break

In [None]:
from scr.utils import print_scenarios

def process_pkl_files(base_dir):
    """Collect scenarios from every ``*.pkl`` file in the sub-directories of ``base_dir``.

    Sub-directories are visited in numeric order of their names; entries whose
    name is not all digits sort last. Files that cannot be opened or unpickled
    (missing, empty, truncated or corrupt) are reported and skipped instead of
    aborting the whole scan.

    Args:
        base_dir: ``pathlib.Path`` of the directory holding the batch folders.

    Returns:
        A flat list of the items returned by ``process_pkl_file`` for each
        readable file. Empty when nothing is found.
    """
    def _batch_order(path):
        # Numeric folder names sort by value; anything else goes to the end.
        return int(path.name) if path.name.isdigit() else float('inf')

    pkl_list = []

    for batch_dir in sorted(base_dir.iterdir(), key=_batch_order):
        if not batch_dir.is_dir():
            continue

        for pkl_file in batch_dir.glob("*.pkl"):
            try:
                with open(pkl_file, 'rb') as file:
                    # SECURITY: pickle.load can execute arbitrary code; only point
                    # this at files produced by this project.
                    game_data = pickle.load(file)
            except (OSError, pickle.UnpicklingError, EOFError) as e:
                # EOFError: empty/truncated file; UnpicklingError: corrupt content.
                print(f"Error reading file {pkl_file}: {e}")
                continue

            # Processing each pickle file
            pkl_list.extend(process_pkl_file(pkl_file, game_data))

    return pkl_list

def process_pkl_file(pkl_file, game_data):
    """Turn the records of one pickle file into ``(pkl_file, scenario)`` pairs.

    Args:
        pkl_file: Path of the pickle file; its name is parsed by
            ``extract_info_from_filename`` for word, difficulty and outcome.
        game_data: Iterable of ``(game_won, guesses)`` pairs loaded from the file.

    Returns:
        A list with one ``(pkl_file, scenario)`` tuple per record, where
        ``scenario`` has the keys ``'word'``, ``'difficulty'``, ``'outcome'`` and
        ``'data'`` (the original ``(game_won, guesses)`` pair). Empty input yields
        an empty list without parsing the file name.
    """
    records = list(game_data)
    if not records:
        return []

    # The file name is the same for every record, so parse it once instead of
    # once per record. The initial state is part of the name but unused here.
    word, _initial_state, difficulty, outcome = extract_info_from_filename(pkl_file)

    return [
        (pkl_file, {
            'word': word,
            'difficulty': difficulty,
            'outcome': outcome,
            'data': (game_won, guesses),
        })
        for game_won, guesses in records
    ]

def extract_info_from_filename(pkl_file):
    """Parse ``(word, initial_state, difficulty, outcome)`` out of a pickle file name.

    The stem is split on ``'_from_'``. In the part before it, the last
    ``'_'``-separated token is the initial state and the preceding tokens,
    re-joined with ``'_'``, form the word. In the part after it, the last two
    ``'_'``-separated tokens are the difficulty and the outcome.
    """
    segments = pkl_file.stem.split('_from_')
    head, tail = segments[0], segments[1]

    *word_tokens, initial_state = head.split('_')
    difficulty, outcome = tail.split('_')[-2:]

    return '_'.join(word_tokens), initial_state, difficulty, outcome

# def print_scenarios(scenarios):
#     # Assuming this function is defined elsewhere
#     pass

# Gather scenarios from every pickle file under the dataset directory
pkl_list = process_pkl_files(base_dataset_dir)

# Show a single entry, selected by position in the collected list
index_to_access = 0  # Change this index to inspect a different entry
if index_to_access >= len(pkl_list):
    print(f"No pickle file at index {index_to_access}")
else:
    file_path, scenario = pkl_list[index_to_access]
    print(f"Contents of {file_path}:")
    # print_scenarios is called with a list, so wrap the single scenario
    print_scenarios([scenario])

No pickle file at index 0



In [None]:
# Inline variant of the pickle scan: builds `pkl_list` as (pkl_file, scenario) pairs.
pkl_list = []

# Iterate over all batch directories, numeric names first in numeric order
for batch_dir in sorted(base_dataset_dir.iterdir(),
                        key=lambda x: int(x.name) if x.name.isdigit() else float('inf')):
    if not batch_dir.is_dir():
        continue

    # Visit all .pkl files in the current batch directory
    for pkl_file in batch_dir.glob("*.pkl"):
        with open(pkl_file, 'rb') as file:
            # SECURITY: pickle.load can execute arbitrary code; only load files
            # produced by this project.
            game_data = pickle.load(file)

        # Extract information from file name.
        # Fixed: the list was previously bound to the misspelled name
        # `word_and_statet`, so the following lines raised NameError.
        parts = pkl_file.stem.split('_from_')
        word_and_state = parts[0].split('_')
        word = '_'.join(word_and_state[:-1])
        initial_state = word_and_state[-1]
        difficulty, outcome = parts[1].split('_')[-2:]

        # Assuming game_data is a list of tuples (game_won, guesses)
        for data in game_data:
            game_won, guesses = data
            # Create a scenario dictionary for each data tuple
            scenario = {
                'word': word,
                'difficulty': difficulty,
                'outcome': outcome,
                'data': (game_won, guesses)
            }
            pkl_list.append((pkl_file, scenario))  # Add scenario to the list

# Accessing an individual pickle file's content by index
index_to_access = 0  # Change this index to access different files
if index_to_access < len(pkl_list):
    file_path, scenario = pkl_list[index_to_access]
    print(f"Contents of {file_path}:")
    print_scenarios([scenario])  # Wrap scenario in a list for the function
else:
    print(f"No pickle file at index {index_to_access}")