trying to implement the HRM model architecture


Input layer:
- Puzzle embedding and sequence embedding

deep supervision
- runs the input and y_true from the dataloader
- run one HRM block
- find loss from cross entropyu
- train / backprop
- 

In [5]:
from torch import nn
from torch.utils.data import DataLoader

path_data = "./data/sudoku-extreme-1k-aug-1000"

In [6]:
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Load dataset metadata
with open(f"{path_data}/train/dataset.json", "r") as f:
    dataset_metadata = json.load(f)

with open(f"{path_data}/identifiers.json", "r") as f:
    identifiers = json.load(f)

print("Dataset Metadata:")
for key, value in dataset_metadata.items():
    print(f"  {key}: {value}")

print(f"\nIdentifiers: {identifiers}")

# Load dataset arrays
train_inputs = np.load(f"{path_data}/train/all__inputs.npy")
train_labels = np.load(f"{path_data}/train/all__labels.npy")
train_puzzle_ids = np.load(f"{path_data}/train/all__puzzle_identifiers.npy")
train_puzzle_indices = np.load(f"{path_data}/train/all__puzzle_indices.npy")
train_group_indices = np.load(f"{path_data}/train/all__group_indices.npy")

print(f"\nDataset Shapes:")
print(f"  Inputs: {train_inputs.shape}")
print(f"  Labels: {train_labels.shape}")
print(f"  Puzzle IDs: {train_puzzle_ids.shape}")
print(f"  Puzzle Indices: {train_puzzle_indices.shape}")
print(f"  Group Indices: {train_group_indices.shape}")

print(f"\nData Types:")
print(f"  Inputs: {train_inputs.dtype}")
print(f"  Labels: {train_labels.dtype}")
print(f"  Puzzle IDs: {train_puzzle_ids.dtype}")

Dataset Metadata:
  pad_id: 0
  ignore_label_id: 0
  blank_identifier_id: 0
  vocab_size: 11
  seq_len: 81
  num_puzzle_identifiers: 1
  total_groups: 1000
  mean_puzzle_examples: 1.0
  sets: ['all']

Identifiers: ['<blank>']

Dataset Shapes:
  Inputs: (1001000, 81)
  Labels: (1001000, 81)
  Puzzle IDs: (1001000,)
  Puzzle Indices: (1001001,)
  Group Indices: (1001,)

Data Types:
  Inputs: int64
  Labels: int64
  Puzzle IDs: int32


In [7]:
# Display the first 10 Sudoku puzzles
def display_sudoku_grid(puzzle, title="Sudoku"):
    """Display a single 9x9 Sudoku grid with proper formatting"""
    grid = puzzle.reshape(9, 9)
    print(f"\n{title}:")
    print("  " + " ".join([str(i) for i in range(9)]))
    print("  " + "-" * 17)
    for i, row in enumerate(grid):
        row_str = f"{i}|"
        for j, val in enumerate(row):
            if val == 0:
                row_str += " ."
            else:
                row_str += f" {val}"
        print(row_str)

def display_puzzle_pair(idx, input_puzzle, solution_puzzle):
    """Display input and solution side by side"""
    print(f"\n{'='*50}")
    print(f"PUZZLE {idx + 1}")
    print(f"Group Index: {train_group_indices[idx]}")
    print(f"Puzzle ID: {train_puzzle_ids[idx]}")
    print(f"{'='*50}")
    
    # Convert to grids
    input_grid = input_puzzle.reshape(9, 9)
    solution_grid = solution_puzzle.reshape(9, 9)
    
    print("\nINPUT (_ = blank)        SOLUTION")
    print("  0 1 2 3 4 5 6 7 8      0 1 2 3 4 5 6 7 8")
    print("  -----------------      -----------------")
    
    for i in range(9):
        # Input row
        input_row = f"{i}|"
        
        for val in input_grid[i]:
            val -= 1
            if val == 0:
                input_row += " _"
            else:
                input_row += f" {val}"
        
        # Solution row
        solution_row = f"    {i}|"
        for val in solution_grid[i]:
            val -= 1
            solution_row += f" {val}"
            
        print(input_row + solution_row)
    
    # Count filled vs blank cells
    filled_cells = np.sum(input_puzzle != 0)
    blank_cells = np.sum(input_puzzle == 0)
    print(f"\nStatistics: {filled_cells} filled, {blank_cells} blank cells")

# Display first 10 puzzles
print("FIRST 10 SUDOKU PUZZLES IN DATASET")
print("=" * 60)

for i in range(min(10, len(train_inputs))):
    display_puzzle_pair(i, train_inputs[i], train_labels[i])
    
    # Add a small separator between puzzles
    if i < 9:
        print("\n" + "-" * 60)

FIRST 10 SUDOKU PUZZLES IN DATASET

PUZZLE 1
Group Index: 0
Puzzle ID: 0

INPUT (_ = blank)        SOLUTION
  0 1 2 3 4 5 6 7 8      0 1 2 3 4 5 6 7 8
  -----------------      -----------------
0| _ _ _ _ _ _ _ 8 _    0| 9 2 6 3 5 1 4 8 7
1| 5 _ _ _ _ _ 3 _ 1    1| 5 8 4 7 6 2 3 9 1
2| _ _ 7 9 8 _ _ _ _    2| 1 3 7 9 8 4 6 2 5
3| _ _ _ _ _ 3 _ _ 4    3| 8 7 1 6 9 3 2 5 4
4| _ 5 2 _ 4 7 _ _ 9    4| 3 5 2 1 4 7 8 6 9
5| 6 _ _ _ 2 _ 1 _ 3    5| 6 4 9 8 2 5 1 7 3
6| _ 6 _ 5 _ _ _ _ _    6| 4 6 8 5 3 9 7 1 2
7| _ 9 3 _ _ _ _ 4 8    7| 7 9 3 2 1 6 5 4 8
8| _ _ _ _ 7 _ _ _ _    8| 2 1 5 4 7 8 9 3 6

Statistics: 81 filled, 0 blank cells

------------------------------------------------------------

PUZZLE 2
Group Index: 1001
Puzzle ID: 0

INPUT (_ = blank)        SOLUTION
  0 1 2 3 4 5 6 7 8      0 1 2 3 4 5 6 7 8
  -----------------      -----------------
0| _ _ 8 _ 6 _ _ _ _    0| 1 7 8 3 6 9 5 4 2
1| 5 9 6 _ 8 _ _ _ 1    1| 5 9 6 2 8 4 3 7 1
2| _ _ _ _ _ 1 _ _ 9    2| 3 2 4 7 5 1 6 8 9
3| 4