trying to implement the HRM model architecture


Input layer:
- Puzzle embedding and sequence embedding

deep supervision
- runs the input and y_true from the dataloader
- run one HRM block
- find loss from cross entropyu
- train / backprop
- 

In [1]:
from torch import nn
from torch.utils.data import DataLoader

path_data = "./data/sudoku-extreme-1k-aug-1000"

In [2]:
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Load dataset metadata
with open(f"{path_data}/train/dataset.json", "r") as f:
    dataset_metadata = json.load(f)

with open(f"{path_data}/identifiers.json", "r") as f:
    identifiers = json.load(f)

print("Dataset Metadata:")
for key, value in dataset_metadata.items():
    print(f"  {key}: {value}")

print(f"\nIdentifiers: {identifiers}")

# Load dataset arrays
train_inputs = np.load(f"{path_data}/train/all__inputs.npy")
train_labels = np.load(f"{path_data}/train/all__labels.npy")
train_puzzle_ids = np.load(f"{path_data}/train/all__puzzle_identifiers.npy")
train_puzzle_indices = np.load(f"{path_data}/train/all__puzzle_indices.npy")
train_group_indices = np.load(f"{path_data}/train/all__group_indices.npy")

print(f"\nDataset Shapes:")
print(f"  Inputs: {train_inputs.shape}")
print(f"  Labels: {train_labels.shape}")
print(f"  Puzzle IDs: {train_puzzle_ids.shape}")
print(f"  Puzzle Indices: {train_puzzle_indices.shape}")
print(f"  Group Indices: {train_group_indices.shape}")

print(f"\nData Types:")
print(f"  Inputs: {train_inputs.dtype}")
print(f"  Labels: {train_labels.dtype}")
print(f"  Puzzle IDs: {train_puzzle_ids.dtype}")

Dataset Metadata:
  pad_id: 0
  ignore_label_id: 0
  blank_identifier_id: 0
  vocab_size: 11
  seq_len: 81
  num_puzzle_identifiers: 1
  total_groups: 1000
  mean_puzzle_examples: 1.0
  sets: ['all']

Identifiers: ['<blank>']

Dataset Shapes:
  Inputs: (1001000, 81)
  Labels: (1001000, 81)
  Puzzle IDs: (1001000,)
  Puzzle Indices: (1001001,)
  Group Indices: (1001,)

Data Types:
  Inputs: int64
  Labels: int64
  Puzzle IDs: int32


In [3]:


def display_puzzle_pair(idx, input_puzzle, solution_puzzle):
    """Display input and solution side by side"""
    print(f"\n{'='*50}")
    print(f"PUZZLE {idx + 1}")
    print(f"Group Index: {train_group_indices[idx]}")
    print(f"Puzzle ID: {train_puzzle_ids[idx]}")
    print(f"{'='*50}")
    
    # Convert to grids
    input_grid = input_puzzle.reshape(9, 9)
    solution_grid = solution_puzzle.reshape(9, 9)
    
    print("\nINPUT (_ = blank)        SOLUTION")
    print("  0 1 2 3 4 5 6 7 8      0 1 2 3 4 5 6 7 8")
    print("  -----------------      -----------------")

    pad_id = dataset_metadata["pad_id"] # from datajson file
    pad = input_puzzle[pad_id] # position of the pad
    for i in range(9):
        # Input row
        input_row = f"{i}|"
        
        for val in input_grid[i]:
            val -= pad
            if val == 0:
                input_row += " _"
            else:
                input_row += f" {val}"
        
        # Solution row
        solution_row = f"    {i}|"
        for val in solution_grid[i]:
            val -= pad
            solution_row += f" {val}"
            
        print(input_row + solution_row)
    
    # Count filled vs blank cells
    filled_cells = np.sum(input_puzzle != 0)
    blank_cells = np.sum(input_puzzle == 0)
    print(f"\nStatistics: {filled_cells} filled, {blank_cells} blank cells")

# Display first 10 puzzles
print("FIRST 10 SUDOKU PUZZLES IN DATASET")
print("=" * 60)

for i in range(min(10, len(train_inputs))):
    display_puzzle_pair(i, train_inputs[i], train_labels[i])
    
    # Add a small separator between puzzles
    if i < 9:
        print("\n" + "-" * 60)

FIRST 10 SUDOKU PUZZLES IN DATASET

PUZZLE 1
Group Index: 0
Puzzle ID: 0

INPUT (_ = blank)        SOLUTION
  0 1 2 3 4 5 6 7 8      0 1 2 3 4 5 6 7 8
  -----------------      -----------------
0| _ 1 _ _ 5 _ _ 8 _    0| 6 1 4 2 5 7 3 8 9
1| _ _ _ _ _ _ _ _ 2    1| 5 8 7 3 9 1 6 4 2
2| 9 _ _ 4 _ _ 7 _ _    2| 9 2 3 4 6 8 7 1 5
3| _ 7 _ 6 _ _ 1 _ _    3| 8 7 5 6 2 9 1 3 4
4| 1 _ _ 7 8 _ _ 5 _    4| 1 3 9 7 8 4 2 5 6
5| _ _ 6 _ _ 3 _ _ _    5| 2 4 6 5 1 3 8 9 7
6| 7 9 _ _ _ _ _ 2 _    6| 7 9 8 1 4 6 5 2 3
7| _ _ 1 9 _ _ 4 _ _    7| 3 5 1 9 7 2 4 6 8
8| 4 _ _ _ 3 _ _ _ _    8| 4 6 2 8 3 5 9 7 1

Statistics: 81 filled, 0 blank cells

------------------------------------------------------------

PUZZLE 2
Group Index: 1001
Puzzle ID: 0

INPUT (_ = blank)        SOLUTION
  0 1 2 3 4 5 6 7 8      0 1 2 3 4 5 6 7 8
  -----------------      -----------------
0| _ 3 _ _ 8 5 2 _ _    0| 1 3 9 7 8 5 2 4 6
1| 5 _ _ _ _ 2 _ _ 3    1| 5 7 8 4 6 2 9 1 3
2| _ _ 6 3 _ _ _ _ _    2| 4 2 6 3 9 1 7 5 8
3| 6

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from typing import Dict, Tuple, Optional
import numpy as np

class SudokuDataset(Dataset):
    """
    PyTorch Dataset for Sudoku puzzles optimized for machine learning training
    """
    
    def __init__(self, data_path: str, split: str = "train", transform=None):
        """
        Initialize the Sudoku dataset
        
        """
        self.data_path = data_path
        self.split = split
        self.transform = transform
        
        # Load metadata
        self._load_metadata()
        
        # Load data
        self._load_data()
        
    def _load_metadata(self):
        """Load metadata from dataset.json"""
        metadata_path = f"{self.data_path}/{self.split}/dataset.json"
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
            
        self.pad_id = metadata["pad_id"]
        self.blank_identifier_id = metadata["blank_identifier_id"]
        self.vocab_size = metadata["vocab_size"]
        self.seq_len = metadata["seq_len"]
        self.ignore_label_id = metadata.get("ignore_label_id", 0)
        
    def _load_data(self):
        """Load the numpy arrays and convert to tensors"""
        split_path = f"{self.data_path}/{self.split}"
        
        # Load and convert to PyTorch tensors
        self.inputs = torch.from_numpy(np.load(f"{split_path}/all__inputs.npy")).long()
        self.labels = torch.from_numpy(np.load(f"{split_path}/all__labels.npy")).long()
        self.puzzle_ids = torch.from_numpy(np.load(f"{split_path}/all__puzzle_identifiers.npy")).long()
        self.group_indices = torch.from_numpy(np.load(f"{split_path}/all__group_indices.npy")).long()
        
        print(f"Loaded {self.split} dataset:")
        print(f"  Samples: {len(self.inputs)}")
        print(f"  Input shape: {self.inputs.shape}")
        print(f"  Label shape: {self.labels.shape}")
        print(f"  Vocab size: {self.vocab_size}")
        print(f"  Sequence length: {self.seq_len}")
        
    def __len__(self):
        return len(self.inputs)
    
    def __getitem__(self, idx):
        """
        Get a single sample from the dataset
        
        Returns:
            Dict with input, target, and metadata
        """
        sample = {
            'input': self.inputs[idx],           # Shape: (81,) - flattened 9x9 grid
            'target': self.labels[idx],          # Shape: (81,) - solution grid
            'puzzle_id': self.puzzle_ids[idx],   # Puzzle identifier
            'group_id': self.group_indices[idx], # Group identifier
        }
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample
    
    def get_metadata(self):
        """Return dataset metadata"""
        return {
            'pad_id': self.pad_id,
            'blank_identifier_id': self.blank_identifier_id,
            'vocab_size': self.vocab_size,
            'seq_len': self.seq_len,
            'ignore_label_id': self.ignore_label_id
        }

def collate_fn(batch):
    """
    Custom collate function for batching Sudoku samples
    
    Args:
        batch: List of samples from SudokuDataset
        
    Returns:
        Batched tensors
    """
    inputs = torch.stack([item['input'] for item in batch])
    targets = torch.stack([item['target'] for item in batch])
    puzzle_ids = torch.stack([item['puzzle_id'] for item in batch])
    group_ids = torch.stack([item['group_id'] for item in batch])
    
    return {
        'input': inputs,      # Shape: (batch_size, 81)
        'target': targets,    # Shape: (batch_size, 81)
        'puzzle_id': puzzle_ids,  # Shape: (batch_size,)
        'group_id': group_ids,    # Shape: (batch_size,)
    }

def create_dataloaders(data_path: str, 
                      batch_size: int = 32,
                      train_split: float = 0.8,
                      val_split: float = 0.2,
                      num_workers: int = 2,
                      shuffle: bool = True) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
    """
    Create train, validation, and test dataloaders
    
    Returns:
        Tuple of (train_loader, val_loader, test_loader)
    """
    
    # Load training dataset
    train_dataset = SudokuDataset(data_path, split="train")
    
    # Split train dataset into train/validation
    total_size = len(train_dataset)
    train_size = int(train_split * total_size)
    val_size = total_size - train_size
    
    train_subset, val_subset = random_split(
        train_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)  # For reproducibility
    )
    
    # Create dataloaders
    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=torch.cuda.is_available()
    )
    
    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=torch.cuda.is_available()
    )
    
    # Try to load test dataset
    test_loader = None
    try:
        test_dataset = SudokuDataset(data_path, split="test")
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=torch.cuda.is_available()
        )
        print(f"Test dataset loaded with {len(test_dataset)} samples")
    except:
        print("No test dataset found, using only train/validation split")
    
    return train_loader, val_loader, test_loader

# Create the dataloaders
print("CREATING SUDOKU DATALOADERS")
print("=" * 50)

train_loader, val_loader, test_loader = create_dataloaders(
    data_path=path_data,
    batch_size=32,
    train_split=0.8,
    val_split=0.2,
    num_workers=2,
    shuffle=True
)

print(f"\nDataLoader Summary:")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
if test_loader:
    print(f"  Test batches: {len(test_loader)}")

# Test the dataloaders
print(f"\nTesting DataLoaders:")
print("-" * 30)

# Get a sample batch from training loader
for batch_idx, batch in enumerate(train_loader):
    print(f"Sample Training Batch:")
    print(f"  Input shape: {batch['input'].shape}")
    print(f"  Target shape: {batch['target'].shape}")
    print(f"  Input dtype: {batch['input'].dtype}")
    print(f"  Target dtype: {batch['target'].dtype}")
    print(f"  Puzzle IDs: {batch['puzzle_id'][:5]}...")  # Show first 5
    print(f"  Input range: [{batch['input'].min().item()}, {batch['input'].max().item()}]")
    print(f"  Target range: [{batch['target'].min().item()}, {batch['target'].max().item()}]")
    break

# Display a sample from the batch
sample_input = batch['input'][0].numpy().reshape(9, 9)
sample_target = batch['target'][0].numpy().reshape(9, 9)

print(f"\nSample Puzzle from Batch:")
print("INPUT (0=blank):")
print(sample_input)
print("\nTARGET (solution):")
print(sample_target)

# Get dataset metadata
dataset = SudokuDataset(path_data, split="train")
metadata = dataset.get_metadata()
print(f"\nDataset Metadata:")
for key, value in metadata.items():
    print(f"  {key}: {value}")

print(f"\nDataLoaders ready for machine learning training!")

CREATING SUDOKU DATALOADERS
Loaded train dataset:
  Samples: 1001000
  Input shape: torch.Size([1001000, 81])
  Label shape: torch.Size([1001000, 81])
  Vocab size: 11
  Sequence length: 81
Loaded train dataset:
  Samples: 1001000
  Input shape: torch.Size([1001000, 81])
  Label shape: torch.Size([1001000, 81])
  Vocab size: 11
  Sequence length: 81
Loaded test dataset:
  Samples: 422786
  Input shape: torch.Size([422786, 81])
  Label shape: torch.Size([422786, 81])
  Vocab size: 11
  Sequence length: 81
Test dataset loaded with 422786 samples

DataLoader Summary:
  Training batches: 25025
  Validation batches: 6257
  Test batches: 13213

Testing DataLoaders:
------------------------------
Loaded test dataset:
  Samples: 422786
  Input shape: torch.Size([422786, 81])
  Label shape: torch.Size([422786, 81])
  Vocab size: 11
  Sequence length: 81
Test dataset loaded with 422786 samples

DataLoader Summary:
  Training batches: 25025
  Validation batches: 6257
  Test batches: 13213

Testin

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/homebrew/anaconda3/envs/hrm/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/hrm/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'SudokuDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>


KeyboardInterrupt: 