In [None]:
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
import pickle

class SudokuDataset(Dataset):
    """
    Unified Sudoku dataset combining multiple sources.

    Boards are loaded from three Hugging Face datasets and converted to
    ``(question_tensor, answer_tensor)`` tuples:

    - ``Ritvik19/Sudoku-Dataset``: the ``train`` / ``validation`` splits are
      grouped by each row's ``missing`` value into ``self.train_subsplits`` /
      ``self.test_subsplits`` (dicts mapping ``missing`` -> list of samples).
    - ``sapientinc/sudoku-extreme``: the ``train`` / ``test`` splits are stored
      under the ``'extreme'`` key of those same two dicts.
    - ``SakanaAI/Sudoku-Bench``: the ``challenge_100`` and ``nikoli_100`` test
      splits are stored as the lists ``self.challenge`` and ``self.nikoli``.

    This class is a container: it does not implement ``__len__`` or
    ``__getitem__`` itself. Use :meth:`get_dataloader` to iterate over a split.

    Args:
        process: Processing method for sudoku boards
            - "one-hot": 729-dim flat vector with one-hot encoding per cell
            - "tokenized": (81, 9) matrix with 9-dim vector per cell
            - "number": 81-dim vector with numbers (0 for empty)
        base_train_max: Maximum number of train samples from Ritvik19 dataset (default: None = use all)
        base_test_max: Maximum number of validation samples from Ritvik19 dataset (default: None = use all)
        extreme_train_max: Maximum number of train samples from extreme dataset (default: None = use all)
        extreme_test_max: Maximum number of test samples from extreme dataset (default: None = use all)

    Raises:
        ValueError: If ``process`` is not one of the supported methods.
    """

    # Supported board encodings, see the class docstring.
    _PROCESS_TYPES = ("one-hot", "tokenized", "number")
    # Characters that denote an empty cell in a board string.
    _EMPTY_CELLS = ('.', '0')

    class _SplitDataset(Dataset):
        """Minimal map-style dataset over a list of (question, answer) pairs."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    def __init__(self, process="tokenized", base_train_max=None, base_test_max=None, extreme_train_max=None, extreme_test_max=None):
        super().__init__()

        if process not in self._PROCESS_TYPES:
            raise ValueError(f"Invalid process type: {process}. Must be 'one-hot', 'tokenized', or 'number'")

        self.process = process
        self.base_train_max = base_train_max
        self.base_test_max = base_test_max
        self.extreme_train_max = extreme_train_max
        self.extreme_test_max = extreme_test_max

        # Load and process all datasets
        self._load_and_process_datasets()

        print("\nDataset loaded successfully!")
        print(f"Processing method: {process}")
        # Compare against None (not truthiness) so an explicit limit of 0 is reported too.
        if base_train_max is not None:
            print(f"Base train max samples: {base_train_max}")
        if base_test_max is not None:
            print(f"Base test max samples: {base_test_max}")
        if extreme_train_max is not None:
            print(f"Extreme train max samples: {extreme_train_max}")
        if extreme_test_max is not None:
            print(f"Extreme test max samples: {extreme_test_max}")

    @staticmethod
    def _subsample(split, limit, label):
        """Return ``split`` shuffled (seed 42) and truncated to ``limit`` rows.

        ``limit=None`` returns ``split`` unchanged. ``label`` is only used in
        the progress message.
        """
        if limit is None:
            return split
        original_size = len(split)
        subset = split.shuffle(seed=42).select(range(min(limit, original_size)))
        print(f"  Selected {len(subset)} items from {label} (originally {original_size})")
        return subset

    def _group_by_missing(self, split, desc):
        """Encode every row of ``split`` and group the pairs by the row's 'missing' value."""
        groups = {}
        for item in tqdm(split, desc=desc, leave=True):
            pair = (self._process_board(item['puzzle']), self._process_board(item['solution']))
            groups.setdefault(item['missing'], []).append(pair)
        return groups

    def _load_and_process_datasets(self):
        """Load all datasets and process them according to the specified method"""

        # Ritvik19/Sudoku-Dataset: train/validation become the per-'missing' subsplits.
        print("Loading Ritvik19/Sudoku-Dataset...")
        ritvik_dataset = load_dataset("Ritvik19/Sudoku-Dataset")

        print("Processing Ritvik19 train split...")
        ritvik_train = self._subsample(ritvik_dataset['train'], self.base_train_max, "Ritvik19 train")
        self.train_subsplits = self._group_by_missing(ritvik_train, desc="Processing Ritvik19 train")

        print("Processing Ritvik19 validation split...")
        ritvik_test = self._subsample(ritvik_dataset['validation'], self.base_test_max, "Ritvik19 validation")
        self.test_subsplits = self._group_by_missing(ritvik_test, desc="Processing Ritvik19 validation")

        # sapientinc/sudoku-extreme: stored under the 'extreme' key of both dicts.
        print("Loading sapientinc/sudoku-extreme...")
        sudoku_extreme = load_dataset("sapientinc/sudoku-extreme")

        print("Processing extreme train split...")
        extreme_train = self._subsample(sudoku_extreme['train'], self.extreme_train_max, "extreme train")
        self.train_subsplits['extreme'] = self._process_split(
            extreme_train['question'],
            extreme_train['answer'],
            desc="Processing extreme train"
        )

        # The extreme test split is subsampled as well when extreme_test_max is set.
        print("Processing extreme test split...")
        extreme_test = self._subsample(sudoku_extreme['test'], self.extreme_test_max, "extreme test")
        self.test_subsplits['extreme'] = self._process_split(
            extreme_test['question'],
            extreme_test['answer'],
            desc="Processing extreme test"
        )

        # SakanaAI/Sudoku-Bench: two standalone evaluation sets, never subsampled.
        print("Loading SakanaAI/Sudoku-Bench challenge...")
        sudoku_bench_ch = load_dataset("SakanaAI/Sudoku-Bench", 'challenge_100')
        self.challenge = self._process_split(
            sudoku_bench_ch['test']['initial_board'],
            sudoku_bench_ch['test']['solution'],
            desc="Processing challenge"
        )

        print("Loading SakanaAI/Sudoku-Bench nikoli...")
        sudoku_bench_nik = load_dataset("SakanaAI/Sudoku-Bench", 'nikoli_100')
        self.nikoli = self._process_split(
            sudoku_bench_nik['test']['initial_board'],
            sudoku_bench_nik['test']['solution'],
            desc="Processing nikoli"
        )

    def _process_split(self, questions, answers, desc="Processing"):
        """Encode parallel sequences of question/answer board strings.

        Returns:
            List of ``(question_tensor, answer_tensor)`` tuples, in input order.
        """
        return [
            (self._process_board(q), self._process_board(a))
            for q, a in tqdm(zip(questions, answers), total=len(questions), desc=desc, leave=True)
        ]

    def _process_board(self, board_str):
        """Process a single sudoku board string according to the specified method.

        Each character is one cell; '.' and '0' mark an empty cell, any other
        character is converted with ``int``.

        Returns:
            float32 tensor: one value per cell for "number", a (cells, 9)
            matrix for "tokenized", or that matrix flattened for "one-hot".

        Raises:
            ValueError: If ``self.process`` is not a supported method (instead
                of silently returning ``None``), or a cell is not a digit.
            IndexError: If a cell holds a value above 9 in a one-hot encoding.
        """
        if self.process not in self._PROCESS_TYPES:
            raise ValueError(f"Invalid process type: {self.process}. Must be 'one-hot', 'tokenized', or 'number'")

        digits = [0 if char in self._EMPTY_CELLS else int(char) for char in board_str]

        if self.process == "number":
            return torch.tensor(digits, dtype=torch.float32)

        # "one-hot" and "tokenized" share the per-cell encoding; only the
        # final layout differs (flat vector vs. one row per cell).
        cells = []
        for digit in digits:
            one_hot = [0.0] * 9
            if digit:
                one_hot[digit - 1] = 1.0
            cells.append(one_hot)
        encoded = torch.tensor(cells, dtype=torch.float32)
        return encoded.reshape(-1) if self.process == "one-hot" else encoded

    @staticmethod
    def _collect_samples(subsplits, min_empty, max_empty, include_extreme):
        """Concatenate the subsplits keyed min_empty..max_empty (inclusive), plus 'extreme' if requested."""
        data = []
        for missing in range(min_empty, max_empty + 1):
            data.extend(subsplits.get(missing, ()))
        if include_extreme:
            data.extend(subsplits.get('extreme', ()))
        return data

    def get_dataloader(self, split_name, batch_size=32, shuffle=False, min_empty=20, max_empty=64, include_extreme=True):
        """
        Create a DataLoader for a specific split

        Args:
            split_name: Name of the split ('train', 'test', 'test_extreme', 'challenge', 'nikoli')
            batch_size: Batch size for the DataLoader
            shuffle: Whether to shuffle the data
            min_empty: Minimum number of empty cells for data selection
                (only used by 'train' and 'test')
            max_empty: Maximum number of empty cells for data selection
                (only used by 'train' and 'test')
            include_extreme: Whether to include extreme data
                (only used by 'train' and 'test')

        Returns:
            DataLoader for the specified split

        Raises:
            ValueError: If ``split_name`` is not one of the names above.
        """
        from torch.utils.data import DataLoader

        # Attributes are read only inside the matching branch so that a split
        # can be served even if unrelated attributes are missing.
        if split_name == 'train':
            data = self._collect_samples(self.train_subsplits, min_empty, max_empty, include_extreme)
        elif split_name == 'test':
            data = self._collect_samples(self.test_subsplits, min_empty, max_empty, include_extreme)
        elif split_name == 'test_extreme':
            data = self.test_subsplits.get('extreme', [])
        elif split_name == 'challenge':
            data = self.challenge
        elif split_name == 'nikoli':
            data = self.nikoli
        else:
            raise ValueError(f"Invalid split name: {split_name}")

        return DataLoader(self._SplitDataset(data), batch_size=batch_size, shuffle=shuffle)

    def get_empty_distribution(self, split_name, min_empty=20, max_empty=64):
        """Get distribution of empty cells for train or test split.

        Returns:
            Dict mapping each subsplit key within ``[min_empty, max_empty]`` to
            its sample count (the 'extreme' subsplit is skipped), or ``None``
            when ``split_name`` is neither 'train' nor 'test'.
        """
        if split_name == 'train':
            subsplits = self.train_subsplits
        elif split_name == 'test':
            subsplits = self.test_subsplits
        else:
            return None

        return {
            key: len(samples)
            for key, samples in subsplits.items()
            if key != 'extreme' and min_empty <= key <= max_empty
        }

def sample_to_grid(tensor, process="tokenized"):
    """
    Convert a processed sudoku tensor back to 9x9 grid format for display

    Args:
        tensor: Processed sudoku tensor. Tensors that require grad are
            accepted for every format (the tensor is detached first).
        process: Processing method used ("one-hot", "tokenized", or "number")

    Returns:
        String representation of the sudoku grid: nine newline-terminated
        rows with '.' for empty cells, '| ' between 3-column bands and a
        dashed separator line between 3-row bands.

    Raises:
        ValueError: If ``process`` is not a known processing method.
    """
    # Convert tensor to a flat list of numbers (1-9, 0 for empty)
    if process == "number":
        # detach() keeps this branch usable on tensors that track gradients;
        # the integer cast truncates toward zero.
        numbers = tensor.detach().cpu().reshape(-1).to(torch.int64).tolist()
    elif process in ("one-hot", "tokenized"):
        # Both formats hold 81 cells x 9 channels; only the layout differs.
        cells = tensor.detach().reshape(81, 9)
        digits = cells.argmax(dim=1) + 1
        # A cell whose channels sum to zero is treated as empty.
        digits[cells.sum(dim=1) == 0] = 0
        numbers = digits.tolist()
    else:
        raise ValueError(f"Unknown process type: {process}")

    # Build the grid line by line and join once at the end
    lines = []
    for row in range(9):
        if row in (3, 6):
            lines.append("------+-------+------")

        symbols = []
        for col in range(9):
            value = numbers[row * 9 + col]
            symbols.append("." if value == 0 else str(value))

        # Every cell is followed by a space; bands are separated by "| ".
        bands = (" ".join(symbols[start:start + 3]) + " " for start in (0, 3, 6))
        lines.append("| ".join(bands))

    return "\n".join(lines) + "\n"

def display_sample(dataset, split='train', index=0):
    """Placeholder for displaying a dataset sample in grid format.

    Currently a no-op: the arguments are ignored, nothing is printed and
    ``None`` is returned.
    """
    # TODO: This function would need to be updated to work with the new structure
    return None

def save_dataset(dataset, filepath):
    """Serialize ``dataset`` to ``filepath`` with pickle.

    The file is opened in binary write mode (replacing any existing file) and
    the object is written with the highest pickle protocol available. A
    confirmation line is printed afterwards.
    """
    with open(filepath, 'wb') as out_file:
        pickler = pickle.Pickler(out_file, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(dataset)
    print(f"Dataset saved to {filepath}")

def load_saved_dataset(filepath):
    """Read back an object previously written to ``filepath`` with pickle.

    Prints a confirmation line and returns the unpickled object.

    NOTE: unpickling can execute arbitrary code, so only load files that come
    from a trusted source.
    """
    with open(filepath, 'rb') as in_file:
        restored = pickle.Unpickler(in_file).load()
    print(f"Dataset loaded from {filepath}")
    return restored

# dataset = SudokuDataset(process="tokenized", base_train_max=100000, extreme_train_max=100000)
# save_dataset(dataset, "/content/drive/MyDrive/Masters Thesis/Idea Experiments/sudoku_tokenized_extended.pkl")
# dataset = load_saved_dataset("/content/drive/MyDrive/Masters Thesis/Idea Experiments/sudoku_tokenized_extended.pkl")

In [None]:
import torch
import random
import numpy as np
from tqdm import tqdm


def augment_dataset(dataset, augment_factor=32, seed=None, show_progress=True):
    """
    Augment training data in-place with geometric transformations and number permutations.

    For every sample in ``dataset.train_subsplits`` the original pair is kept
    and ``augment_factor - 1`` extra pairs are generated. Each extra pair
    applies one randomly chosen symmetry of the 9x9 square (identity, three
    rotations, two flips, transpose, anti-transpose) to both boards, followed
    with 50% probability by one random relabelling of the digits 1-9 applied
    to both boards.

    All three board encodings are supported and keep their shape: (81, 9)
    "tokenized", (729,) "one-hot" and (81,) "number" (any tensor with exactly
    81 elements is treated as the "number" encoding).

    Calling this twice on the same dataset augments the already augmented
    data again.

    Args:
        dataset: SudokuDataset instance to augment (any object exposing a
            ``train_subsplits`` dict of ``(input, target)`` lists works)
        augment_factor: Total number of samples per original (including
            original); values below 1 keep only the originals
        seed: Random seed for reproducibility (seeds ``random``, NumPy and torch)
        show_progress: Show a tqdm progress bar; when False the per-subsplit
            summary is printed with ``print`` and tqdm is not used

    Returns:
        The same dataset object with augmented train_subsplits
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # The eight symmetries of the square, each acting on a (9, 9, C) grid
    # view (C = 9 for one-hot/tokenized boards, C = 1 for number boards).
    grid_symmetries = [
        lambda grid: grid,                        # Identity
        lambda grid: grid.rot90(-1, [0, 1]),      # Rotate 90 degrees clockwise
        lambda grid: grid.rot90(-2, [0, 1]),      # Rotate 180 degrees
        lambda grid: grid.rot90(-3, [0, 1]),      # Rotate 270 degrees clockwise
        lambda grid: grid.flip(0),                # Flip vertically
        lambda grid: grid.flip(1),                # Flip horizontally
        lambda grid: grid.transpose(0, 1),        # Transpose (main diagonal)
        # Anti-transpose (anti-diagonal): out[i][j] = in[8-j][8-i], i.e. a
        # 180-degree rotation followed by a transpose. A 90-degree rotation
        # followed by a transpose would just be another vertical flip.
        lambda grid: grid.rot90(-2, [0, 1]).transpose(0, 1),
    ]

    def apply_symmetry(tensor, symmetry):
        """Apply a grid symmetry and restore the board's original shape."""
        # clone() guarantees the result never shares storage with the source.
        grid = tensor.clone().reshape(9, 9, -1)
        return symmetry(grid).reshape(tensor.shape)

    def relabel_digits(tensor, perm):
        """Relabel digits so that old digit ``perm[k] + 1`` becomes ``k + 1``."""
        if tensor.numel() == 81:
            # "number" encoding: translate cell values through a lookup
            # table; index 0 (empty cell) always maps to 0.
            lookup = torch.zeros(10, dtype=tensor.dtype, device=tensor.device)
            lookup[perm + 1] = torch.arange(1, 10, dtype=tensor.dtype, device=tensor.device)
            return lookup[tensor.long()]
        # One-hot encodings: permute the 9 digit channels of every cell.
        return tensor.reshape(81, 9)[:, perm].reshape(tensor.shape)

    print(f"Augmenting training data with factor {augment_factor}...")

    # Count total samples
    total_samples = sum(len(data) for data in dataset.train_subsplits.values())

    # tqdm is only touched when a progress bar is requested.
    pbar = tqdm(total=total_samples, desc="Augmenting data", unit="samples") if show_progress else None
    log = tqdm.write if show_progress else print

    try:
        # Augment each subsplit (iterate over a snapshot of the keys because
        # the dict values are replaced inside the loop)
        for key in list(dataset.train_subsplits):
            original_data = dataset.train_subsplits[key]
            augmented_data = []

            if pbar is not None:
                pbar.set_description(f"Augmenting subsplit '{key}'")

            for input_t, target_t in original_data:
                # Add original
                augmented_data.append((input_t, target_t))

                # Add augmented versions
                for _ in range(augment_factor - 1):
                    # Same random geometric transformation for both boards
                    symmetry = grid_symmetries[random.randint(0, 7)]
                    aug_input = apply_symmetry(input_t, symmetry)
                    aug_target = apply_symmetry(target_t, symmetry)

                    # 50% chance for number permutation (shared by both boards)
                    if random.random() < 0.5:
                        perm = torch.randperm(9)
                        aug_input = relabel_digits(aug_input, perm)
                        aug_target = relabel_digits(aug_target, perm)

                    augmented_data.append((aug_input, aug_target))

                if pbar is not None:
                    pbar.update(1)

            # Replace with augmented data
            dataset.train_subsplits[key] = augmented_data
            log(f"  Subsplit '{key}': {len(original_data)} -> {len(augmented_data)} samples")
    finally:
        if pbar is not None:
            pbar.close()

    print(f"Augmentation complete! Total samples: {sum(len(data) for data in dataset.train_subsplits.values())}")

    return dataset