In [None]:
def sample_to_grid(tensor, process="tokenized"):
    """
    Convert a processed sudoku tensor back to 9x9 grid format for display

    Args:
        tensor: Processed sudoku tensor
        process: Processing method used ("one-hot", "tokenized", or "number")

    Returns:
        String representation of the sudoku grid
    """
    # Convert tensor to list of numbers (1-9, 0 for empty)
    if process == "number":
        # Already in number format, just convert 0 to '.'
        numbers = tensor.cpu().numpy().astype(int)
    elif process == "one-hot":
        # Reshape from 729 to (81, 9) and get argmax
        reshaped = tensor.reshape(81, 9)
        numbers = []
        for cell in reshaped:
            if cell.sum() == 0:
                numbers.append(0)
            else:
                numbers.append(cell.argmax().item() + 1)
    elif process == "tokenized":
        # Shape is (81, 9), get argmax for each cell
        numbers = []
        for cell in tensor:
            if cell.sum() == 0:
                numbers.append(0)
            else:
                numbers.append(cell.argmax().item() + 1)
    else:
        raise ValueError(f"Unknown process type: {process}")

    # Build the grid string
    grid_str = ""
    for i in range(9):
        if i % 3 == 0 and i != 0:
            grid_str += "------+-------+------\n"

        for j in range(9):
            if j % 3 == 0 and j != 0:
                grid_str += "| "

            idx = i * 9 + j
            if numbers[idx] == 0:
                grid_str += ". "
            else:
                grid_str += f"{numbers[idx]} "

        grid_str += "\n"

    return grid_str

def create_model():
    """
    Build the network selected by the notebook-level ``model_type`` setting.

    Supported values are "StaticDEQ", "HierarchicalDEQ" and "HyperDEQ"; the
    constructor arguments are read from the corresponding notebook-level
    configuration variables. For "HierarchicalDEQ" the per-stage settings are
    comma-separated strings that are parsed into lists of ints. When
    ``use_sudoku_positional`` is truthy the model is passed through
    ``add_sudoku_positional_encoding``.

    Returns:
        The model moved to ``device``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type not in ("StaticDEQ", "HierarchicalDEQ", "HyperDEQ"):
        raise ValueError(f"Unknown model type: {model_type}")

    def parse_int_list(text):
        # "2, 3,4" -> [2, 3, 4]
        return [int(item.strip()) for item in text.split(',')]

    # Options every model variant takes.
    common = dict(
        hid_activation=hidden_activation,
        output_activation=output_activation,
        weight_init=weight_initialization,
        bias=use_bias,
    )

    if model_type == "StaticDEQ":
        net = StaticDEQ(
            T=T, C=C, D=static_hidden_dimension,
            L=static_iterations, N=static_num_weight_matrices,
            **common
        )
    elif model_type == "HierarchicalDEQ":
        stage_iterations = parse_int_list(hier_iterations_per_stage)
        stage_matrices = parse_int_list(hier_weight_matrices_per_stage)
        stage_dims = parse_int_list(hier_stage_dimensions)
        net = HierarchicalDEQ(
            C=C, D=hier_hidden_dimension,
            Ls=stage_iterations, Ns=stage_matrices, s_dims=stage_dims,
            **common,
            weight_share=hier_weight_share
        )
    else:  # "HyperDEQ"
        net = HyperDEQ(
            T=T, C=C, D=hyper_hidden_dimension,
            L=hyper_iterations, N=hyper_num_weight_generation_steps,
            H=hyper_num_heads, E=hyper_head_dimension,
            **common,
            weight_share=hyper_weight_share
        )

    if use_sudoku_positional:
        net = add_sudoku_positional_encoding(net)

    return net.to(device)

def create_optimizer(model):
    """
    Instantiate the optimizer named by the notebook-level ``optimizer_type``.

    Supported names are "Adam", "AdamW", "SGD" and "RMSprop". All of them
    receive ``learning_rate`` and ``weight_decay``; SGD additionally receives
    ``momentum_for_sgd``.

    Args:
        model: Module whose ``parameters()`` are optimized.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``optimizer_type`` is not one of the supported names.
    """
    if optimizer_type not in ("Adam", "AdamW", "SGD", "RMSprop"):
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")

    extra_options = {}
    if optimizer_type == "SGD":
        extra_options["momentum"] = momentum_for_sgd

    # The supported names match the class names in torch.optim.
    optimizer_cls = getattr(optim, optimizer_type)
    return optimizer_cls(model.parameters(), lr=learning_rate,
                         weight_decay=weight_decay, **extra_options)

def digit_entropy_loss(output):
    """
    Entropy regulariser over the digit distribution of each sudoku section.

    For each of the 27 sections (9 rows, 9 columns, 9 3x3 blocks) the
    per-cell digit probabilities are averaged over the section's 9 cells and
    the entropy of that averaged distribution is computed. The entropies are
    averaged over sections and batch and normalised by log(9), the entropy of
    a uniform distribution over the 9 digits.

    Args:
        output: (B, 81, 9) model output with cells in row-major order.
            Softmax is always applied over the last dimension, so logits are
            the expected input; values that are already probabilities get
            softmaxed a second time.

    Returns:
        Scalar tensor holding the NEGATIVE normalised entropy, i.e. a value
        in roughly [-1, 0]. Minimising it pushes every section towards
        containing all digits equally often.
    """
    B = output.shape[0]

    probs = F.softmax(output, dim=-1)
    grid = probs.reshape(B, 9, 9, 9)  # (B, row, col, digit)

    eps = 1e-8  # keeps log() finite when a digit has zero probability

    def mean_entropy(section_dist):
        # section_dist: (B, 9 sections, 9 digits) -> mean entropy (scalar)
        per_section = -(section_dist * torch.log(section_dist + eps)).sum(dim=-1)
        return per_section.mean()

    # Average over the cells of each section with axis reductions instead of
    # looping over the 27 sections in Python.
    row_dist = grid.mean(dim=2)  # average over columns -> one entry per row
    col_dist = grid.mean(dim=1)  # average over rows -> one entry per column
    # Split rows and columns into (band, offset) pairs and average the offsets.
    block_dist = (
        grid.reshape(B, 3, 3, 3, 3, 9)  # (B, block_row, r, block_col, c, digit)
        .mean(dim=4)
        .mean(dim=2)
        .reshape(B, 9, 9)
    )

    # Each group holds 9 sections, so the mean of the three group means
    # equals the mean over all 27 sections.
    avg_entropy = (
        mean_entropy(row_dist) + mean_entropy(col_dist) + mean_entropy(block_dist)
    ) / 3

    max_entropy = torch.log(torch.tensor(9.0, device=output.device))
    normalized_entropy = avg_entropy / max_entropy

    # Higher entropy is better, so negate it to obtain a loss.
    return -normalized_entropy

def compute_loss(output, target, input_mask=None, entropy_weight=0.0):
    """
    Compute cross-entropy loss with optional entropy regularization.

    Args:
        output: (B, T, C) model scores per cell and class.
        target: (B, T, C) one-hot targets; converted to class indices with
            argmax over the last dimension.
        input_mask: Optional (B, T) mask with 1 for cells that count towards
            the loss and 0 otherwise. Only used when the notebook-level
            ``loss_on_empty_cells_only`` flag is truthy.
        entropy_weight: When > 0, ``digit_entropy_loss(output)`` scaled by
            this weight is added to the cross-entropy term.

    Returns:
        Scalar loss tensor. With masking enabled it is the mean
        cross-entropy over the masked cells, or 0 if no cell is masked in.
    """
    B, T_dim, C_dim = output.shape

    # Convert one-hot target to class indices
    target_indices = target.argmax(dim=-1)  # (B, T)

    # reshape (not view) so a non-contiguous model output is accepted too.
    output_flat = output.reshape(B * T_dim, C_dim)
    target_flat = target_indices.reshape(B * T_dim)

    if loss_on_empty_cells_only and input_mask is not None:
        mask_flat = input_mask.reshape(B * T_dim).to(output_flat.dtype)
        per_cell = nn.CrossEntropyLoss(reduction='none')(output_flat, target_flat)
        # A batch without any masked-in cell would otherwise divide 0 by 0 and
        # turn the loss (and every gradient) into NaN. For a 0/1 mask the
        # clamp only changes that all-zero case.
        ce_loss = (per_cell * mask_flat).sum() / mask_flat.sum().clamp(min=1.0)
    else:
        ce_loss = nn.CrossEntropyLoss()(output_flat, target_flat)

    # Optional entropy regularization
    if entropy_weight > 0:
        return ce_loss + entropy_weight * digit_entropy_loss(output)

    return ce_loss

def evaluate_accuracy(output, target, input_data):
    """
    Compute accuracy metrics for a batch of predictions.

    Args:
        output: (B, T, C) model scores; the prediction is the argmax class.
        target: (B, T, C) targets; the label is the argmax class.
        input_data: (B, T, C) puzzle input; a cell whose values sum to zero
            counts as empty.

    Returns:
        Dict of Python floats:
            'complete': fraction of puzzles with every cell predicted correctly
            'empty': accuracy over the empty input cells (1.0 if there are none)
            'overall': accuracy over all cells
    """
    batch, cells, classes = output.shape  # also enforces a 3-D output

    hits = output.argmax(dim=-1) == target.argmax(dim=-1)  # (B, T)
    blank = input_data.sum(dim=-1) == 0  # (B, T)

    blank_total = blank.sum().item()
    if blank_total > 0:
        empty_accuracy = (hits * blank).sum().item() / blank_total
    else:
        empty_accuracy = 1.0

    return {
        'complete': hits.all(dim=1).float().mean().item(),
        'empty': empty_accuracy,
        'overall': hits.float().mean().item()
    }

def evaluate_split(model, split_name, num_samples, criterion=None):
    """
    Evaluate the model on up to ``num_samples`` puzzles of one dataset split.

    The model is switched to eval mode and left there. Puzzles are loaded one
    at a time without shuffling; those whose second dimension is not 81 are
    skipped and do not count towards ``num_samples``.

    Args:
        model: Network to evaluate.
        split_name: Split passed to ``dataset.get_dataloader``. For 'train'
            and 'test' the notebook-level ``min_empty`` / ``max_empty`` /
            ``include_extreme`` settings are forwarded as well.
        num_samples: Maximum number of puzzles to evaluate.
        criterion: Used only as a truthy flag; when set, the loss is computed
            with ``compute_loss`` (without entropy regularization).

    Returns:
        (avg_metrics, avg_loss): the mean of each ``evaluate_accuracy`` metric
        over the evaluated puzzles (0.0 if none were evaluated), and the
        sample-weighted mean loss (0.0 when no loss was requested or nothing
        was evaluated).
    """
    model.eval()

    # batch_size=1 copes with puzzles of different sizes.
    loader_options = {'batch_size': 1, 'shuffle': False}
    if split_name in ['train', 'test']:
        # Only these splits take the empty-cell range parameters.
        loader_options.update(min_empty=min_empty, max_empty=max_empty,
                              include_extreme=include_extreme)
    dataloader = dataset.get_dataloader(split_name, **loader_options)

    loss_sum = 0
    collected = {'complete': [], 'empty': [], 'overall': []}
    seen = 0

    with torch.no_grad():
        for puzzle, solution in dataloader:
            if seen >= num_samples:
                break

            if puzzle.shape[1] != 81:
                continue

            puzzle = puzzle.to(device)
            solution = solution.to(device)
            prediction = model(puzzle)
            in_batch = puzzle.size(0)

            if criterion:
                blank_mask = (puzzle.sum(dim=-1) == 0).float()
                batch_loss = compute_loss(prediction, solution, blank_mask)
                loss_sum += batch_loss.item() * in_batch

            for name, value in evaluate_accuracy(prediction, solution, puzzle).items():
                collected[name].append(value)

            seen += in_batch

    avg_metrics = {}
    for name, values in collected.items():
        avg_metrics[name] = np.mean(values) if values else 0.0

    if seen > 0 and criterion:
        avg_loss = loss_sum / seen
    else:
        avg_loss = 0.0

    return avg_metrics, avg_loss

def evaluate_empty_distribution(model, split_name):
    """
    Measure accuracy separately for each number of empty cells.

    The model is switched to eval mode and left there. For every empty-cell
    count from ``min_empty`` to ``max_empty`` that has samples in the split's
    sub-splits, at most the first 100 samples are evaluated one by one;
    samples whose first dimension is not 81 are skipped.

    Args:
        model: Network to evaluate.
        split_name: 'train' or 'test'.

    Returns:
        Dict mapping empty-cell count to a dict with the mean 'complete',
        'empty' and 'overall' accuracy and the number of evaluated samples
        ('count'). Counts without any evaluated sample are omitted. Returns
        None for any other split name.
    """
    model.eval()

    if split_name not in ('train', 'test'):
        return None
    subsplits = dataset.train_subsplits if split_name == 'train' else dataset.test_subsplits

    results = {}

    with torch.no_grad():
        for empty_count in range(min_empty, max_empty + 1):
            if empty_count not in subsplits:
                continue

            samples = subsplits[empty_count]
            if len(samples) == 0:
                continue

            per_sample = []
            for puzzle, solution in samples[:100]:  # cap the work per empty count
                if puzzle.shape[0] != 81:
                    continue

                # Add the batch dimension the model expects.
                puzzle = puzzle.unsqueeze(0).to(device)
                solution = solution.unsqueeze(0).to(device)

                prediction = model(puzzle)
                per_sample.append(evaluate_accuracy(prediction, solution, puzzle))

            if not per_sample:
                continue

            results[empty_count] = {
                'complete': np.mean([m['complete'] for m in per_sample]),
                'empty': np.mean([m['empty'] for m in per_sample]),
                'overall': np.mean([m['overall'] for m in per_sample]),
                'count': len(per_sample)
            }

    return results

def save_checkpoint(model, optimizer, history, samples_processed):
    """
    Save a training checkpoint into ``save_directory``.

    The file name encodes the model family with a few of its hyperparameters,
    the batch size, learning rate, number of processed samples and a
    timestamp, and ends in ``_checkpoint.pt``. The directory is created if it
    does not exist yet.

    Args:
        model: Network whose ``state_dict`` is stored.
        optimizer: Optimizer whose ``state_dict`` is stored.
        history: Training history, stored as-is.
        samples_processed: Number of training samples seen so far.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create filename with key parameters
    if model_type == "StaticDEQ":
        params_str = f"Static_H{static_hidden_dimension}_I{static_iterations}_W{static_num_weight_matrices}"
    elif model_type == "HierarchicalDEQ":
        params_str = f"Hier_H{hier_hidden_dimension}_S{len(hier_iterations_per_stage.split(','))}"
    else:  # HyperDEQ (also the fallback for any other model_type)
        params_str = f"Hyper_H{hyper_hidden_dimension}_Heads{hyper_num_heads}"

    params_str += f"_bs{batch_size}_lr{learning_rate}_s{samples_processed}_{timestamp}"

    # torch.save does not create missing directories; without this the first
    # checkpoint of a fresh run fails. An empty save_directory means the
    # current directory, which needs no creation.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    checkpoint_path = os.path.join(save_directory, params_str) + '_checkpoint.pt'

    # Save complete checkpoint
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'model_config': {
            'model_type': model_type,
            'T': T, 'C': C,
            'hidden_activation': hidden_activation,
            'output_activation': output_activation,
            'weight_initialization': weight_initialization,
            'use_bias': use_bias,
            'use_sudoku_positional': use_sudoku_positional
        },
        'training_config': {
            'max_training_samples': max_training_samples,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'weight_decay': weight_decay,
            'optimizer_type': optimizer_type,
            'momentum_for_sgd': momentum_for_sgd,
            'loss_on_empty_cells_only': loss_on_empty_cells_only,
            'gradient_clip_value': gradient_clip_value,
            'samples_processed': samples_processed,
            'min_empty': min_empty,
            'max_empty': max_empty,
            'include_extreme': include_extreme
        },
        'history': history
    }, checkpoint_path)

    print(f"Checkpoint saved to {checkpoint_path}")

def _interval_crossed(samples_before, samples_after, interval):
    """Return True when the sample counter passed a multiple of ``interval``."""
    return samples_after // interval > samples_before // interval


def _evaluate_and_record(model, history):
    """Evaluate the held-out splits, append the results to ``history`` and print them."""
    split_specs = [
        ('test', 'Test', test_evaluation_samples),
        ('test_extreme', 'Test Extreme', 100),
        ('challenge', 'Challenge', challenge_evaluation_samples),
        ('nikoli', 'Nikoli', nikoli_evaluation_samples),
    ]

    # Run every evaluation before touching the history so that a failure
    # part-way through cannot leave the per-split lists with unequal lengths.
    results = []
    for split, label, limit in split_specs:
        metrics, loss = evaluate_split(model, split, limit, criterion=True)
        results.append((split, label, metrics, loss))

    for split, _, metrics, loss in results:
        history[f'{split}_loss'].append(loss)
        for key in ['complete', 'empty', 'overall']:
            history[f'{split}_{key}'].append(metrics[key])

    for _, label, metrics, loss in results:
        print(f"{label} - Loss: {loss:.4f}, Complete: {metrics['complete']:.2%}, "
              f"Empty: {metrics['empty']:.2%}, Overall: {metrics['overall']:.2%}")


def train_model():
    """
    Main training loop.

    Builds the model and optimizer from the notebook-level configuration and
    trains on the 'train' split, cycling through the data loader until
    ``max_training_samples`` samples have been processed. Logging, evaluation
    and checkpointing each run once every time the processed-sample counter
    passes a multiple of their interval (``logging_interval_samples``,
    ``evaluation_interval_samples``, ``save_interval_samples``), which is the
    spacing ``plot_training_history`` assumes for its x-axis.

    Returns:
        (model, history): the trained model and a dict of metric lists keyed
        '<split>_<metric>' for the splits train / test / test_extreme /
        challenge / nikoli and the metrics loss / complete / empty / overall.

    Raises:
        RuntimeError: If the training data loader yields no samples, since
            the loop could then never reach ``max_training_samples``.
    """
    model = create_model()
    optimizer = create_optimizer(model)

    history = {
        f'{split}_{metric}': []
        for metric in ('loss', 'complete', 'empty', 'overall')
        for split in ('train', 'test', 'test_extreme', 'challenge', 'nikoli')
    }

    train_loader = dataset.get_dataloader('train', batch_size=batch_size, shuffle=True,
                                          min_empty=min_empty, max_empty=max_empty,
                                          include_extreme=include_extreme)

    model.train()
    samples_processed = 0
    # Running values since the last log line.
    recent_losses = []
    recent_metrics = {'complete': [], 'empty': [], 'overall': []}

    print(f"Starting training for {max_training_samples} samples...")
    pbar = tqdm(total=max_training_samples,
                desc="Training",
                position=0,
                leave=True,
                ncols=100)

    try:
        while samples_processed < max_training_samples:
            samples_at_epoch_start = samples_processed

            for input_data, target_data in train_loader:
                if samples_processed >= max_training_samples:
                    break

                input_data = input_data.to(device)
                target_data = target_data.to(device)

                # Forward pass
                output = model(input_data)

                # Loss; the mask marks cells that are empty in the input.
                input_empty = (input_data.sum(dim=-1) == 0).float()
                loss = compute_loss(output, target_data, input_empty, entropy_weight=entropy_weight)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()

                if gradient_clip_value > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_value)

                optimizer.step()

                # Track metrics
                recent_losses.append(loss.item())
                metrics = evaluate_accuracy(output.detach(), target_data, input_data)
                for key in metrics:
                    recent_metrics[key].append(metrics[key])

                samples_before = samples_processed
                samples_processed += input_data.size(0)
                pbar.update(input_data.size(0))

                # Logging
                if _interval_crossed(samples_before, samples_processed, logging_interval_samples):
                    avg_loss = np.mean(recent_losses)
                    avg_metrics = {key: np.mean(values) for key, values in recent_metrics.items()}

                    print(f"\nSamples: {samples_processed}/{max_training_samples}")
                    print(f"Train Loss: {avg_loss:.4f}")
                    print(f"Train Acc - Complete: {avg_metrics['complete']:.2%}, "
                          f"Empty: {avg_metrics['empty']:.2%}, "
                          f"Overall: {avg_metrics['overall']:.2%}")

                    history['train_loss'].append(avg_loss)
                    for key in ['complete', 'empty', 'overall']:
                        history[f'train_{key}'].append(avg_metrics[key])

                    recent_losses = []
                    recent_metrics = {'complete': [], 'empty': [], 'overall': []}

                # Testing
                if _interval_crossed(samples_before, samples_processed, evaluation_interval_samples):
                    print("\nEvaluating...")
                    _evaluate_and_record(model, history)
                    # evaluate_split leaves the model in eval mode.
                    model.train()

                # Save checkpoint
                if save_model and _interval_crossed(samples_before, samples_processed,
                                                    save_interval_samples):
                    save_checkpoint(model, optimizer, history, samples_processed)

            if samples_processed == samples_at_epoch_start:
                # An empty loader would otherwise keep this while-loop spinning forever.
                raise RuntimeError("Training data loader yielded no samples; cannot make progress.")
    finally:
        # Release the progress bar even if training is interrupted.
        pbar.close()

    print("Training completed!")

    return model, history

def plot_training_history(history):
    """
    Plot loss and accuracy curves from a training history.

    Draws a 2x2 figure (loss, complete-puzzle accuracy, empty-cell accuracy,
    overall accuracy). Training points are placed at multiples of
    ``logging_interval_samples`` and evaluation points at multiples of
    ``evaluation_interval_samples``, i.e. the k-th history entry is assumed
    to belong to the k-th interval. Does nothing when ``show_plots`` is
    falsy; when ``save_plots`` is truthy the figure is also written to
    ``save_directory`` with a timestamped name.

    Args:
        history: Dict of metric lists keyed '<split>_<metric>' for the splits
            train / test / test_extreme / challenge / nikoli.
    """
    if not show_plots:
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # (axes, history suffix, y label, title) for each panel
    panels = [
        (axes[0, 0], 'loss', 'Loss', 'Loss Over Training'),
        (axes[0, 1], 'complete', 'Complete Accuracy', 'Complete Puzzle Accuracy'),
        (axes[1, 0], 'empty', 'Empty Cell Accuracy', 'Empty Cell Accuracy'),
        (axes[1, 1], 'overall', 'Overall Accuracy', 'Overall Accuracy'),
    ]
    # (history prefix, legend label, marker) for the evaluation curves
    eval_curves = [
        ('test', 'Test', 'o'),
        ('test_extreme', 'Test Extreme', 'd'),
        ('challenge', 'Challenge', 's'),
        ('nikoli', 'Nikoli', '^'),
    ]

    for ax, metric, y_label, title in panels:
        train_values = history[f'train_{metric}']
        if train_values:
            x_train = np.arange(1, len(train_values) + 1) * logging_interval_samples
            ax.plot(x_train, train_values, label='Train', marker='o', markersize=3)

        test_values = history[f'test_{metric}']
        if test_values:
            # All evaluation curves share the x positions of the test curve.
            x_test = np.arange(1, len(test_values) + 1) * evaluation_interval_samples
            for split, label, marker in eval_curves:
                ax.plot(x_test, history[f'{split}_{metric}'], label=label, marker=marker)

        ax.set_xlabel('Samples Processed')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_plots:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plt.savefig(os.path.join(save_directory, f'training_history_{stamp}.png'))

    plt.show()

def plot_empty_distribution(model):
    """
    Plot train/test accuracy as a function of the number of empty cells.

    Evaluates the model with ``evaluate_empty_distribution`` on the 'train'
    and 'test' splits and draws three grouped bar charts (complete, empty-cell
    and overall accuracy), labelling every fifth empty-cell count on the
    x-axis. A count missing from one split is drawn as 0. Does nothing when
    ``show_plots`` is falsy; when ``save_plots`` is truthy the figure is also
    written to ``save_directory`` with a timestamped name.

    Args:
        model: Network to evaluate.
    """
    if not show_plots:
        return

    print("\nEvaluating accuracy by number of empty cells...")

    train_dist = evaluate_empty_distribution(model, 'train')
    test_dist = evaluate_empty_distribution(model, 'test')

    if not (train_dist or test_dist):
        print("No distribution data available")
        return

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    empty_counts = sorted(set(train_dist.keys()) | set(test_dist.keys()))
    x = np.arange(len(empty_counts))
    width = 0.35

    # (metric key, y label, title) for each panel
    panels = [
        ('complete', 'Complete Accuracy', 'Complete Puzzle Accuracy by Empty Cells'),
        ('empty', 'Empty Cell Accuracy', 'Empty Cell Accuracy by Empty Cells'),
        ('overall', 'Overall Accuracy', 'Overall Accuracy by Empty Cells'),
    ]

    for ax, (metric, y_label, title) in zip(axes, panels):
        train_values = [train_dist.get(ec, {}).get(metric, 0) for ec in empty_counts]
        test_values = [test_dist.get(ec, {}).get(metric, 0) for ec in empty_counts]

        # Side-by-side bars: train to the left of each tick, test to the right.
        ax.bar(x - width/2, train_values, width, label='Train', alpha=0.8)
        ax.bar(x + width/2, test_values, width, label='Test', alpha=0.8)
        ax.set_xlabel('Number of Empty Cells')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.set_xticks(x[::5])
        ax.set_xticklabels(empty_counts[::5])
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_plots:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plt.savefig(os.path.join(save_directory, f'empty_distribution_{stamp}.png'))

    plt.show()

def display_predictions(model, num_samples=None):
    """
    Display sample predictions from each split.

    Draws one figure row per split (train, test, test_extreme, challenge,
    nikoli). Every sample takes three text panels: the input puzzle, the
    model's prediction and the target, each rendered with ``sample_to_grid``.
    Puzzles are drawn from a shuffled loader one at a time; those whose
    second dimension is not 81 are skipped. Panels for which a split yields
    no puzzle stay blank. The model is switched to eval mode and left there.

    Does nothing when ``show_plots`` is falsy or no samples are requested;
    when ``save_plots`` is truthy the figure is also written to
    ``save_directory`` with a timestamped name.

    Args:
        model: Network used to produce the predictions.
        num_samples: Samples to show per split; defaults to the notebook-level
            ``num_prediction_samples_to_show``.
    """
    if not show_plots:
        return

    if num_samples is None:
        num_samples = num_prediction_samples_to_show
    if num_samples <= 0:
        # plt.subplots cannot build a grid with zero columns.
        return

    model.eval()

    splits = ['train', 'test', 'test_extreme', 'challenge', 'nikoli']
    n_cols = num_samples * 3

    # squeeze=False always returns a 2-D axes array, so no reshaping is needed.
    fig, axes = plt.subplots(len(splits), n_cols,
                             figsize=(3 * n_cols, 3 * len(splits)),
                             squeeze=False)

    # Hide frames and ticks on every panel up front; panels that never get a
    # sample would otherwise show an empty coordinate system.
    for ax in axes.flat:
        ax.axis('off')

    for split_idx, split_name in enumerate(splits):
        # batch_size=1 copes with puzzles of different sizes.
        loader_options = {'batch_size': 1, 'shuffle': True}
        if split_name in ['train', 'test']:
            # Only these splits take the empty-cell range parameters.
            loader_options.update(min_empty=min_empty, max_empty=max_empty,
                                  include_extreme=include_extreme)
        dataloader = dataset.get_dataloader(split_name, **loader_options)

        samples_shown = 0
        with torch.no_grad():
            for input_data, target_data in dataloader:
                if samples_shown >= num_samples:
                    break

                # Skip non-81 cell puzzles
                if input_data.shape[1] != 81:
                    continue

                input_data = input_data.to(device)
                output = model(input_data)

                panels = (
                    ('Input', input_data[0]),
                    ('Prediction', output[0]),
                    ('Target', target_data[0]),
                )
                for offset, (label, cells) in enumerate(panels):
                    ax = axes[split_idx, samples_shown * 3 + offset]
                    ax.set_title(f'{split_name} - {label} {samples_shown+1}')
                    ax.text(0.5, 0.5, sample_to_grid(cells.cpu(), 'tokenized'),
                            ha='center', va='center', fontfamily='monospace', fontsize=8)

                samples_shown += 1

    plt.tight_layout()

    if save_plots:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plt.savefig(os.path.join(save_directory, f'predictions_{timestamp}.png'))

    plt.show()