##### Imports

In [1]:
import sys
from pathlib import Path
import warnings

import warnings
import pandas as pd
warnings.filterwarnings('ignore')

pd.set_option('display.max_columns', 1000)
pd.set_option('display.max_rows', 1000)

import sys
# Custom library paths
sys.path.extend(['../', './scr'])

from scr.utils import set_seed
from scr.utils import read_words
from pathlib import Path
import random
from collections import Counter, defaultdict
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset
from scr.feature_engineering import \
    calculate_char_frequencies, calculate_word_frequencies
from scr.utils import read_words, save_words_to_file

import pickle
from pathlib import Path
from scr.dataset import *
from scr.game import *
from scr.plot_utils import *

import gc

set_seed(42)

import torch
import torch.nn as nn
from pathlib import Path
import random

from scr.utils import print_scenarios
torch.set_float32_matmul_precision('medium')

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Read and Shuffle Word List
word_list = read_words('data/words_250000_train.txt') # , limit=10000)
# word_list = read_words('data/250k.txt', limit=10000)
random.shuffle(word_list)

# base_dataset_dir = Path('dataset/pkl')

##### Reading Data

In [2]:
from pathlib import Path

NUM_STRATIFIED_SAMPLES = 50000
# # # Define the base directory
# base_dataset_dir = Path(f'./dataset/{NUM_STRATIFIED_SAMPLES}')
base_dataset_dir = Path(f"/media/sayem/510B93E12554BBD1/dataset/{NUM_STRATIFIED_SAMPLES}")

# Ensuring the base directory and 'pkl' subdirectory exist
base_dataset_dir.mkdir(parents=True, exist_ok=True)
pkls_dir = base_dataset_dir / 'pkl'
pkls_dir.mkdir(parents=True, exist_ok=True)

# Paths to the words files
train_words_file_path = base_dataset_dir / 'train_words.txt'
test_words_file_path = base_dataset_dir / 'test_words.txt'

# Read the words from the files
try:
    train_words = read_words(train_words_file_path)
    print(f"Loaded {len(train_words)} train words from {train_words_file_path}")
except FileNotFoundError:
    print(f"File not found: {train_words_file_path}")

Loaded 181840 train words from /media/sayem/510B93E12554BBD1/dataset/50000/train_words.txt


In [3]:
base_dataset_dir

PosixPath('/media/sayem/510B93E12554BBD1/dataset/50000')

In [4]:
# # For inference
from scr.feature_engineering import *

word_frequencies = calculate_word_frequencies(word_list)
char_frequency = calculate_char_frequencies(word_list)
max_word_length = max(len(word) for word in word_list)

##### Model Building

In [5]:
from scr.simple_model import SimpleLSTM
from scr.base_model import BaseModel
from pathlib import Path
import torch
from scr.feature_engineering import *

# max_word_length = 29 # TODO will remove later
# Instantiate and test the model
config = {
    'embedding_dim': 200,
    'hidden_dim': 256,
    'num_layers': 2,
    'vocab_size': 27,
    'max_word_length': max_word_length,
    'input_feature_size': 5,
    'use_embedding': True,
    'miss_linear_dim': 50,
    'lr': 0.001
}

model = SimpleLSTM(config)
optimizer = model.optimizer

# Assuming 'model' is your trained model instance
model.save_model(file_path='models/model.pth') ## TODO: _ will remove later


# Assuming the saved model file is 'models/model.pth'
model_file_path = 'models/model.pth' ## TODO: _ will remove later

# Specify the device to load the model onto
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Load the model
# loaded_model = BaseModel.load_model(SimpleLSTM, model_file_path)
# # Now `loaded_model` is an instance of `SimpleLSTM` with the state and config loaded

# model = loaded_model

##### Dataset Loading and train-test split

In [6]:
pkls_dir

PosixPath('/media/sayem/510B93E12554BBD1/dataset/50000/pkl')

In [7]:
from scr.dataset import *
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import random
from scr.model_training import *

# # Load the dataset
processed_dataset = ProcessedHangmanDataset(pkls_dir, \
    char_frequency, max_word_length, files_limit=None)
    
print(len(processed_dataset))

2100726


In [8]:
processed_dataset[1000]

(['___b____e', 'a_aba___e', 'anaba__ne', 'anaba_ine'],
 ['a', 'n', 'i', 's'],
 {'word': 'anabasine',
  'initial_state': '___b_____',
  'game_state': 'early',
  'difficulty': 'easy',
  'outcome': 'win',
  'word_length': '9'})

In [9]:
# Convert PyTorch dataset to a list for train_test_split
dataset_list = [processed_dataset[i] for i in range(len(processed_dataset))]

# Perform an 80%-20% train-test split
train_data, val_data = train_test_split(dataset_list, \
    test_size=0.2, random_state=42)

val_loader = processed_dataset.create_val_loader(val_data)

del dataset_list

In [10]:
from collections import defaultdict
import random

# Define the maximum number of games you want to play per epoch
max_games_per_epoch = 1000  # Example value, adjust as needed

# Assuming train_data_loader is your DataLoader for training data
unique_words_set = set()
for _, _, _, batch_full_words in val_loader:
    unique_words_set.update(batch_full_words)

# Group words by their lengths
words_by_length = defaultdict(list)
for word in unique_words_set:
    words_by_length[len(word)].append(word)

# Calculate the number of words to select from each length group
num_words_per_length = max_games_per_epoch // len(words_by_length)

selected_words = []
for length, words in words_by_length.items():
    random.shuffle(words)
    selected_words.extend(words[:num_words_per_length])

# # # Now, use 'selected_words' in place of 'unique_words' for game simulation
# selected_words

In [11]:
# # # Initialize DataLoader outside the loop
# batch_size = 64
# # train_loader = DataLoader(train_data, \
# #     batch_size=batch_size, \
# #         collate_fn=processed_dataset.custom_collate_fn)
# epochs = 10
# # model.train()

# for epoch in range(epochs):
#     train_loader = DataLoader(train_data, batch_size=batch_size, \
#         collate_fn=processed_dataset.custom_collate_fn)

#     train_loss, train_miss_penalty = train_on_data_loader(model, train_loader, device, optimizer)
#     print(f"Epoch {epoch}: Training Loss: {train_loss}, Miss Penalty: {train_miss_penalty}")

##### Untrained Model Performence

In [12]:
from scr.model_training import validate_hangman
# Now call the validate_hangman function with this set
validation_results = validate_hangman(model, val_loader, \
    char_frequency, max_word_length, device, selected_words)

In [13]:
from scr.plot_utils import *

print(f'~Untrained model performence~')
# # Print results for character-level validation
# print("Character Level Validation:")
print(f"Average Loss: {validation_results['avg_loss']}")
print(f"Miss Penalty: {validation_results['avg_miss_penalty']}")


game_simulation_results = validation_results["game_simulation"]

save_path = Path(f"plots/{NUM_STRATIFIED_SAMPLES}_untrained_model_performence_word_stats_plot.png")
plot_word_stats(game_simulation_results["length_stats"], save_path)

# # Print results for game simulation
# print("\nGame Simulation:")
print(f"Win Rate: {game_simulation_results['win_rate']}")
print(f"Average Attempts: {game_simulation_results['average_attempts']}")
print(f"Total Games: {game_simulation_results['total_games']}")
print(f"Total Wins: {game_simulation_results['total_wins']}")
print(f"Total Losses: {game_simulation_results['total_losses']}")
# print("Word Stats:", game_simulation_results["game_stats"])
# print("Word Length Stats:", game_simulation_results["length_stats"])

~Untrained model performence~
Average Loss: 18.828806423899568
Miss Penalty: 0.41289582541733083


Win Rate: 0.054317548746518104
Average Attempts: 5.908077994428969
Total Games: 718
Total Wins: 39
Total Losses: 679


In [14]:
init_performance_dict = game_simulation_results["length_stats"]

# init_performance_dict = {
#     length: {"wins": 1, "losses": 1, "total_attempts": 2, "games": 2}
#     for length in range(1, max_word_length + 1)
# }

init_performance_dict

{5: {'wins': 1, 'losses': 34, 'total_attempts': 208, 'games': 35},
 21: {'wins': 2, 'losses': 20, 'total_attempts': 129, 'games': 22},
 15: {'wins': 1, 'losses': 34, 'total_attempts': 209, 'games': 35},
 22: {'wins': 3, 'losses': 10, 'total_attempts': 72, 'games': 13},
 11: {'wins': 2, 'losses': 33, 'total_attempts': 207, 'games': 35},
 19: {'wins': 3, 'losses': 32, 'total_attempts': 206, 'games': 35},
 3: {'wins': 0, 'losses': 35, 'total_attempts': 210, 'games': 35},
 24: {'wins': 2, 'losses': 3, 'total_attempts': 27, 'games': 5},
 9: {'wins': 0, 'losses': 35, 'total_attempts': 210, 'games': 35},
 18: {'wins': 3, 'losses': 32, 'total_attempts': 203, 'games': 35},
 20: {'wins': 6, 'losses': 29, 'total_attempts': 202, 'games': 35},
 2: {'wins': 1, 'losses': 34, 'total_attempts': 208, 'games': 35},
 17: {'wins': 1, 'losses': 34, 'total_attempts': 208, 'games': 35},
 14: {'wins': 0, 'losses': 35, 'total_attempts': 210, 'games': 35},
 12: {'wins': 1, 'losses': 34, 'total_attempts': 209, 'g

In [15]:
# # # # Initialize DataLoader outside the loop
# # from scr.dataset import *
# batch_size = 256

# custom_sampler = PerformanceBasedSampler(train_data, batch_size, init_performance_dict)
# train_loader = DataLoader(train_data, sampler=custom_sampler, \
#     collate_fn=processed_dataset.custom_collate_fn)

# # # train_loader = DataLoader(train_data, batch_size=batch_size, \
# # #     collate_fn=processed_dataset.custom_collate_fn)

In [16]:
# # %%capture

# for i, batch in enumerate(train_loader):
#     # print(batch)
#     # break
#     # # if i == 10:
#     # #     break
#     pass

##### Training

In [17]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

def train_model(model, train_data, init_performance_dict, \
    val_loader, num_epochs, optimizer, batch_size, \
    unique_words_set, scheduler=None, device=device):
    
    best_val_loss = float('inf')
    n_epochs_stop = 10

    for epoch in range(num_epochs):
        # Initialize DataLoader outside the loop
        custom_sampler = PerformanceBasedSampler(train_data, batch_size, init_performance_dict)
        train_loader = DataLoader(train_data, sampler=custom_sampler, \
            collate_fn=processed_dataset.custom_collate_fn)

        # # Update only the sampler for each epoch
        # custom_sampler = PerformanceBasedSampler(train_data, batch_size, \
        #     init_performance_dict)
        # train_loader.sampler = custom_sampler

        model.train()
        train_loss, train_miss_penalty = train_on_data_loader(model, \
            train_loader, device, optimizer)
        print(f"Epoch {epoch}: Training Loss: {train_loss}, \
Miss Penalty: {train_miss_penalty}")

        model.eval()
        with torch.no_grad():
            validation_results = validate_hangman(model, val_loader, \
                char_frequency, max_word_length, device, unique_words_set)
            val_loss = validation_results['avg_loss']
            val_miss_penalty = validation_results["avg_miss_penalty"]
            init_performance_dict = validation_results["game_simulation"]["length_stats"]

        if scheduler:
            scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve == n_epochs_stop:
                print("Early stopping triggered")
                break

        print(f"Epoch {epoch}: val loss: {val_loss}, val miss penalty {val_miss_penalty}")
        print(f"Win Rate: {validation_results['game_simulation']['win_rate']}")
        print()

# Rest of the code to call train_model

# # Assuming model, init_train_data, val_loader, num_epochs, optimizer, batch_size are already defined
# init_performance_dict = {length: {"wins": 0, "losses": 0, "total_attempts": 0, "games": 0} \
#     for length in range(1, max_word_length + 1)}

scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1, verbose=True)

num_epochs = 10
batch_size = 256 # hyperparams

train_model(model, train_data, init_performance_dict, \
    val_loader, num_epochs, optimizer, batch_size, selected_words, scheduler)



Epoch 0: Training Loss: 3.855369137903791, Miss Penalty: 0.03661621908795757
Epoch 0: val loss: 16.51804387825212, val miss penalty 0.03594462970351736
Win Rate: 0.16991643454038996

Epoch 1: Training Loss: 3.805361223895434, Miss Penalty: 0.03633534110903976
Epoch 1: val loss: 18.221619068958095, val miss penalty 0.029768001999721787
Win Rate: 0.15598885793871867

Epoch 2: Training Loss: 3.7744552142433663, Miss Penalty: 0.03614896926103736
Epoch 2: val loss: 21.366754138046083, val miss penalty 0.03709219447231691
Win Rate: 0.22562674094707522

Epoch 3: Training Loss: 3.7642798807754883, Miss Penalty: 0.03612618157926593
Epoch 3: val loss: 31.96382179610007, val miss penalty 0.02984225980421366
Win Rate: 0.201949860724234

Epoch 4: Training Loss: 3.7651465017869783, Miss Penalty: 0.03619539079153814
Epoch 4: val loss: 30.995442114458694, val miss penalty 0.0347942233198908
Win Rate: 0.1977715877437326

Epoch 5: Training Loss: 3.765891452828049, Miss Penalty: 0.03619203050981422
Epoch

In [None]:
# from sklearn.model_selection import KFold
# import torch
# from torch.utils.data import DataLoader, Subset

# def train_model(model, train_data, val_data, num_epochs, optimizer, scheduler, device):
#     train_loader = DataLoader(train_data, batch_size=batch_size, \
#         collate_fn=custom_collate_fn)
#     val_loader = DataLoader(val_data, batch_size=batch_size, \
#         collate_fn=custom_collate_fn)

#     for epoch in range(num_epochs):
#         train_loss, train_miss_penalty = train_epoch(model, train_loader, \
#             optimizer, device)
#         val_loss, val_miss_penalty = validate_epoch(model, \
#             val_loader, device)

#         scheduler.step(val_loss)  # Adjust LR based on validation loss

#         print(f"Epoch {epoch}: Training Loss: {train_loss}, Miss Penalty: \
#             {train_miss_penalty}, Validation Loss: {val_loss}, Validation Miss Penalty: {val_miss_penalty}")
#         # Save model checkpoints if needed

# def k_fold_cross_validate(model, dataset, k, num_epochs, optimizer, scheduler_class, device):
#     kfold = KFold(n_splits=k, shuffle=True, random_state=42)
    
#     for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
#         print(f"Fold {fold}")
#         train_subset = Subset(dataset, train_idx)
#         val_subset = Subset(dataset, val_idx)

#         # Initialize a new model for each fold
#         model = SimpleLSTM(config)
#         model.to(device)
#         optimizer = model.optimizer
#         scheduler = scheduler_class(optimizer)

#         train_subset = Subset(dataset, train_idx)
#         val_subset = Subset(dataset, val_idx)
#         train_model(model, train_subset, val_subset, num_epochs, optimizer, scheduler, device)


#         train_model(model, train_subset, val_subset, num_epochs, optimizer, scheduler, device)

# # Usage example
# num_epochs = 10
# batch_size = 64
# scheduler_class = torch.optim.lr_scheduler.ReduceLROnPlateau  # Example scheduler class
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# k_fold_cross_validate(model, processed_dataset, \
#     5, num_epochs, optimizer, scheduler_class, device)

##### Trained Model Performence

In [None]:
val_loader = processed_dataset.create_val_loader(val_data)


# Call the validation function
validation_results = validate_hangman(model, val_loader, \
    char_frequency, max_word_length, device)

# Access the results
character_level_results = validation_results["character_level"]
game_simulation_results = validation_results["game_simulation"]

from scr.plot_utils import *

save_path = Path(f"plots/{NUM_STRATIFIED_SAMPLES}_trained_model_performence_word_stats_plot.png")
plot_word_stats(game_simulation_results["length_stats"], save_path)

# # Print results for character-level validation
# print("Character Level Validation:")
print(f"Average Loss: {character_level_results['avg_loss']}")
print(f"Miss Penalty: {character_level_results['avg_miss_penalty']}")
# print(f"Character Accuracy: {character_level_results['char_accuracy']}")
# print(f"Word Accuracy: {character_level_results['word_accuracy']}")
# print("Word Statistics:", character_level_results["word_stats"])  # Updated key

# # Print results for game simulation
# print("\nGame Simulation:")
print(f"Win Rate: {game_simulation_results['win_rate']}")
print(f"Average Attempts: {game_simulation_results['average_attempts']}")
print(f"Total Games: {game_simulation_results['total_games']}")
print(f"Total Wins: {game_simulation_results['total_wins']}")
print(f"Total Losses: {game_simulation_results['total_losses']}")
# print("Word Stats:", game_simulation_results["game_stats"])
# print("Word Length Stats:", game_simulation_results["length_stats"])

In [None]:
print("Word Stats:", game_simulation_results["game_stats"])

In [None]:
STOP

In [None]:
batch_features_tensor.shape

In [None]:
len(val_loader)

In [None]:
batch_features_tensor.shape

In [None]:
STOP

In [None]:
STOP

In [None]:
def train_epoch(model, data_loader, optimizer, device):
    model.train()  # Set the model to training mode
    total_loss = 0
    total_miss_penalty = 0

    for batch in data_loader:
        if batch[0] is None:
            continue  # Skip empty batches

        game_states_batch, lengths_batch, missed_chars_batch, labels_batch = batch
        game_states_batch, lengths_batch, missed_chars_batch = \
            game_states_batch.to(device), lengths_batch, missed_chars_batch.to(device)

        optimizer.zero_grad()

        outputs = model(game_states_batch, lengths_batch, missed_chars_batch)
        model_output_shape = outputs.shape
        reshaped_labels = pad_and_reshape_labels(labels_batch, model_output_shape).to(device)

        loss, miss_penalty = model.calculate_loss(outputs, reshaped_labels, \
                                                  lengths_batch, missed_chars_batch, 27)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_miss_penalty += miss_penalty.item()  # Accumulate miss penalty

    avg_loss = total_loss / len(data_loader)
    avg_miss_penalty = total_miss_penalty / len(data_loader)
    return avg_loss, avg_miss_penalty  # Return average loss and miss penalty



def validate_epoch(model, data_loader, device):
    model.eval()  # Set the model to evaluation mode
    total_loss = 0
    total_miss_penalty = 0

    with torch.no_grad():
        for batch in data_loader:
            if batch[0] is None:
                continue  # Skip empty batches


            game_states_batch, lengths_batch, missed_chars_batch, labels_batch = batch
            game_states_batch, lengths_batch, missed_chars_batch = \
                game_states_batch.to(device), lengths_batch, missed_chars_batch.to(device)

            outputs = model(game_states_batch, lengths_batch, missed_chars_batch)
            model_output_shape = outputs.shape
            reshaped_labels = pad_and_reshape_labels(labels_batch, model_output_shape).to(device)

            loss, miss_penalty = model.calculate_loss(outputs, reshaped_labels, \
                                                      lengths_batch, missed_chars_batch, 27)
            total_loss += loss.item()

            total_miss_penalty += miss_penalty.item()  # Accumulate miss penalty

    avg_loss = total_loss / len(data_loader)
    avg_miss_penalty = total_miss_penalty / len(data_loader)
    return avg_loss, avg_miss_penalty  # 

# for epoch in range(num_epochs):
#     train_loss = train_epoch(model, train_loader, optimizer, device)
#     val_loss = validate_epoch(model, val_loader, device)
    
#     print(f"Epoch {epoch}: Training Loss: {train_loss}, Validation Loss: {val_loss}")
#     # You can add code to save model checkpoints if needed

In [None]:
from sklearn.model_selection import KFold
import torch
from torch.utils.data import DataLoader, Subset

def train_model(model, train_data, val_data, num_epochs, optimizer, scheduler, device):
    train_loader = DataLoader(train_data, batch_size=batch_size, \
        collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_data, batch_size=batch_size, \
        collate_fn=custom_collate_fn)

    for epoch in range(num_epochs):
        train_loss, train_miss_penalty = train_epoch(model, train_loader, \
            optimizer, device)
        val_loss, val_miss_penalty = validate_epoch(model, \
            val_loader, device)

        scheduler.step(val_loss)  # Adjust LR based on validation loss

        print(f"Epoch {epoch}: Training Loss: {train_loss}, Miss Penalty: \
            {train_miss_penalty}, Validation Loss: {val_loss}, Validation Miss Penalty: {val_miss_penalty}")
        # Save model checkpoints if needed

def k_fold_cross_validate(model, dataset, k, num_epochs, optimizer, scheduler_class, device):
    kfold = KFold(n_splits=k, shuffle=True, random_state=42)
    
    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f"Fold {fold}")
        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(dataset, val_idx)

        # Initialize a new model for each fold
        model = SimpleLSTM(config)
        model.to(device)
        optimizer = model.optimizer
        scheduler = scheduler_class(optimizer)

        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(dataset, val_idx)
        train_model(model, train_subset, val_subset, num_epochs, optimizer, scheduler, device)


        train_model(model, train_subset, val_subset, num_epochs, optimizer, scheduler, device)

# Usage example
num_epochs = 10
batch_size = 64
scheduler_class = torch.optim.lr_scheduler.ReduceLROnPlateau  # Example scheduler class
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

k_fold_cross_validate(model, processed_dataset, \
    5, num_epochs, optimizer, scheduler_class, device)

In [None]:
STOP

In [None]:
from scr.dataset import ProcessedHangmanDataset
from scr.custom_sampler import PerformanceBasedSampler


print(f"Total number of data points in the dataset: {len(processed_dataset)}")

# Initialize the sampler
sampler = PerformanceBasedSampler(
    processed_dataset, 
    performance_metrics,
    max_word_length=max_word_length
)

print(f"Sampler created with {len(sampler)} indices.")

In [None]:
sample_indices = list(sampler)[:10]  # Get the first 10 sampled indices
print("Sampled indices:", sample_indices)

for idx in sample_indices:
    data_point = processed_dataset[idx]
    print(f"Data at index {idx}: {data_point}")
    break

In [None]:
from scr.utils import print_scenarios

def process_pkl_files(base_dir):
    pkl_list = []

    # Iterate over all batch directories
    for batch_dir in sorted(base_dir.iterdir(), key=lambda x: int(x.name) if x.name.isdigit() else float('inf')):
        if batch_dir.is_dir():
            # List all .pkl files in the current batch directory
            pkl_files = list(batch_dir.glob("*.pkl"))

            for pkl_file in pkl_files:
                try:
                    with open(pkl_file, 'rb') as file:
                        game_data = pickle.load(file)
                except IOError as e:
                    print(f"Error reading file {pkl_file}: {e}")
                    continue

                # Processing each pickle file
                pkl_list.extend(process_pkl_file(pkl_file, game_data))

    return pkl_list

def process_pkl_file(pkl_file, game_data):
    file_scenarios = []
    for data in game_data:
        game_won, guesses = data
        word, initial_state, difficulty, outcome = extract_info_from_filename(pkl_file)
        
        # Create a scenario dictionary for each data tuple
        scenario = {
            'word': word,
            'difficulty': difficulty,
            'outcome': outcome,
            'data': (game_won, guesses)
        }
        file_scenarios.append((pkl_file, scenario))  # Add scenario to the list

    return file_scenarios

def extract_info_from_filename(pkl_file):
    parts = pkl_file.stem.split('_from_')
    word_and_state = parts[0].split('_')
    word = '_'.join(word_and_state[:-1])
    initial_state = word_and_state[-1]
    difficulty, outcome = parts[1].split('_')[-2:]
    return word, initial_state, difficulty, outcome

# def print_scenarios(scenarios):
#     # Assuming this function is defined elsewhere
#     pass

# Process all pickle files
pkl_list = process_pkl_files(base_dataset_dir)

# Accessing an individual pickle file's content by index
index_to_access = 0  # Change this index to access different files
if index_to_access < len(pkl_list):
    file_path, scenario = pkl_list[index_to_access]
    print(f"Contents of {file_path}:")
    print_scenarios([scenario])  # Wrap scenario in a list for the function
else:
    print(f"No pickle file at index {index_to_access}")

No pickle file at index 0



In [None]:
pkl_list = []

# Iterate over all batch directories
for batch_dir in sorted(base_dataset_dir.iterdir(), \
    key=lambda x: int(x.name) if x.name.isdigit() else float('inf')):
    if batch_dir.is_dir():
        # List all .pkl files in the current batch directory
        pkl_files = list(batch_dir.glob("*.pkl"))

        for pkl_file in pkl_files:
            with open(pkl_file, 'rb') as file:
                game_data = pickle.load(file)
                # Extract information from file name
                parts = pkl_file.stem.split('_from_')
                word_and_statet = parts[0].split('_')
                word = '_'.join(word_and_state[:-1])
                initial_state = word_and_state[-1]
                difficulty, outcome = parts[1].split('_')[-2:]

                # Assuming game_data is a list of tuples (game_won, guesses)
                for data in game_data:
                    game_won, guesses = data
                    # Create a scenario dictionary for each data tuple
                    scenario = {
                        'word': word,
                        'difficulty': difficulty,
                        'outcome': outcome,
                        'data': (game_won, guesses)
                    }
                    pkl_list.append((pkl_file, scenario))  # Add scenario to the list

# Accessing an individual pickle file's content by index
index_to_access = 0  # Change this index to access different files
if index_to_access < len(pkl_list):
    file_path, scenario = pkl_list[index_to_access]
    print(f"Contents of {file_path}:")
    print_scenarios([scenario])  # Wrap scenario in a list for the function
else:
    print(f"No pickle file at index {index_to_access}")