### Imports 

In [1]:
import pandas as pd
import json
import matplotlib.pyplot as plt
import chess
import torch
import sys
import os

from tqdm import tqdm 

### Setup

In [2]:
# Make the repository root (two levels above the notebook's working directory)
# importable so that the project-local `src` package can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../../'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Attempt to import create_tensor
try:
    from src.io.to_tensor import create_tensor
except ImportError:
    print("Error: Could not import create_tensor from src.io.to_tensor.")
    print(f"Ensure 'src' directory is in sys.path ('{project_root}' was added) and contains io/to_tensor.py.")
    print(f"Current sys.path includes: {sys.path[:3]} ...")
    # Fail fast: without create_tensor the processing loop cannot convert a
    # single position, and its per-line error handling would hide the problem.
    raise


# Configuration
FILE_PATH = "../../data/raw/lichess_eval/lichess_db_eval.jsonl"  # input JSONL, one record per line
OUTPUT_DIR = '../../data/processed/lichess_eval/'                 # destination for the numbered .pt files
BATCH_SIZE = 1_000_000  # number of valid entries written to each .pt file
MAX_CP = 2000.0         # centipawn scores are clipped to +/-MAX_CP, then divided by MAX_CP

os.makedirs(OUTPUT_DIR, exist_ok=True)

### Processing files

In [3]:
# Counters shared with the processing and summary cells; all start at zero.
batch_num = entries_in_current_batch = total_valid_entries_processed = lines_read = 0

# Accumulators for the batch currently being filled (inputs and their labels).
tensors_batch, labels_batch = [], []

print(
    f"Starting processing of: {FILE_PATH}",
    f"Saving processed tensors in batches of {BATCH_SIZE} to: {OUTPUT_DIR}",
    sep="\n",
)

Starting processing of: ../../data/raw/lichess_eval/lichess_db_eval.jsonl
Saving processed tensors in batches of 1000000 to: ../../data/processed/lichess_eval/


In [4]:
def _extract_cp(record):
    """Return the 'cp' value of the first PV of the first eval as an int, or None.

    Reads record['evals'][0]['pvs'][0]['cp']. Returns None when 'evals' or
    'pvs' is missing/empty, or when that PV carries no 'cp' value.
    """
    evals = record.get('evals')
    if not evals:
        return None
    pvs = evals[0].get('pvs')
    if not pvs:
        return None
    cp_value = pvs[0].get('cp')
    return None if cp_value is None else int(cp_value)


def _save_batch(tensors, labels, file_index, description="batch"):
    """Stack one batch and write it to OUTPUT_DIR/<file_index>.pt.

    The file holds a dict with 'inputs' (stacked tensors) and 'labels'
    (float32 column vector, one row per input). Prints the saved shapes.
    """
    input_tensors = torch.stack(tensors)
    label_tensors = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)

    output_tensor_path = os.path.join(OUTPUT_DIR, f'{file_index}.pt')
    torch.save({'inputs': input_tensors, 'labels': label_tensors}, output_tensor_path)
    print(f"\nSaved {description} {file_index} to {output_tensor_path}. "
          f"Inputs shape: {input_tensors.shape}, Labels shape: {label_tensors.shape}")


# Without create_tensor every line would fail in the per-line handler below;
# stop immediately instead of reading the whole file for nothing.
if 'create_tensor' not in globals():
    raise NameError("create_tensor is not defined; run the setup cell and fix its import first.")

lines_without_fen_or_cp = 0  # well-formed records that simply carry no usable fen/cp
lines_failed = 0             # lines that raised while being parsed or converted
first_failure = None         # description of the first such failure, for the report

try:
    with open(FILE_PATH, 'r', encoding='utf-8') as file:
        for line in tqdm(file, desc="Processing JSONL file", unit="lines"):
            lines_read += 1

            # Per-line best effort: a malformed line, an invalid FEN or a
            # conversion error skips that line only, but is counted.
            try:
                obj = json.loads(line.strip())
                fen = obj.get('fen')
                cp = _extract_cp(obj)
                if fen is None or cp is None:
                    lines_without_fen_or_cp += 1
                    continue
                board = chess.Board(fen)
                tensor = create_tensor(board)
            except Exception as e:
                lines_failed += 1
                if first_failure is None:
                    first_failure = f"line {lines_read}: {type(e).__name__}: {e}"
                continue

            # Clip to [-MAX_CP, MAX_CP] and scale into [-1, 1].
            clipped_cp = max(min(cp, MAX_CP), -MAX_CP)
            tensors_batch.append(tensor)
            labels_batch.append(clipped_cp / MAX_CP)
            entries_in_current_batch += 1
            total_valid_entries_processed += 1

            # Saving happens outside the per-line handler on purpose: a failed
            # write must abort the run rather than be silently retried.
            if entries_in_current_batch >= BATCH_SIZE:
                _save_batch(tensors_batch, labels_batch, batch_num + 1)
                tensors_batch = []
                labels_batch = []
                entries_in_current_batch = 0
                batch_num += 1

    if tensors_batch:
        _save_batch(tensors_batch, labels_batch, batch_num + 1, description="final batch")
        # Release the saved data and count the file, exactly as for full
        # batches, so batch_num equals the number of .pt files written.
        tensors_batch = []
        labels_batch = []
        entries_in_current_batch = 0
        batch_num += 1
    elif total_valid_entries_processed == 0:
        print("\nNo valid data processed from the file. No .pt files were created.")
    else:
        print("\nFinished processing. No remaining data for a final partial batch, or the last batch was empty.")

    if lines_without_fen_or_cp:
        print(f"Lines without a usable fen/cp value: {lines_without_fen_or_cp}")
    if lines_failed:
        print(f"Lines skipped because of errors: {lines_failed} (first: {first_failure})")

except FileNotFoundError:
    print(f"Error: The file {FILE_PATH} was not found.")
except Exception as e:
    print(f"An unexpected error occurred during file processing: {e}")
    # Re-raise so the traceback is visible and later cells do not report success.
    raise

Processing JSONL file: 1208241lines [14:00, 1462.86lines/s]


Saved batch 1 to ../../data/processed/lichess_eval/1.pt. Inputs shape: torch.Size([1000000, 28, 8, 8]), Labels shape: torch.Size([1000000, 1])


Processing JSONL file: 2390633lines [27:20, 1418.24lines/s]


Saved batch 2 to ../../data/processed/lichess_eval/2.pt. Inputs shape: torch.Size([1000000, 28, 8, 8]), Labels shape: torch.Size([1000000, 1])


Processing JSONL file: 3131992lines [37:17, 1399.50lines/s]


KeyboardInterrupt: 

In [None]:
# Final report. A non-empty tensors_batch (with at least one valid entry)
# is counted as one additional .pt file on top of batch_num.
print(
    f"\n--- Processing Summary ---",
    f"Total lines read from file: {lines_read}",
    f"Total valid entries processed and saved: {total_valid_entries_processed}",
    f"Number of .pt files created: {batch_num if not (tensors_batch and total_valid_entries_processed > 0) else batch_num + 1}",
    f"Output directory: {OUTPUT_DIR}",
    "Processing complete.",
    sep="\n",
)