# In [None]:
def predict_token(model, puzzles, positions, device='cpu'):
    """
    Predict the digit at the given positions in sudoku puzzles.

    The model's forward pass is run with ``mask=True``. The output vector at
    each target cell is cut down to its first ``model.embedding_dim``
    components and compared, by cosine similarity, with the embeddings of the
    digit tokens 1-9; the most similar digit is returned.

    Token 0 is deliberately not a candidate: it is the value this file writes
    into a masked cell, so it is never a valid answer, and returning it would
    break the documented 1-9 contract that callers index with ``token - 1``.

    Side effect: the model is switched to eval mode and is left in that mode.

    Args:
        model: The Sudoku2Vec model. Must provide ``eval()``, ``forward()``,
            ``embed()`` and ``embedding_dim``.
        puzzles: Sudoku grids [batch_size, 9, 9] (tensor or array-like)
        positions: The positions [batch_size, 2] as [x, y] (tensor or array-like)
        device: The device to use for computation

    Returns:
        predicted_tokens: The predicted tokens (1-9) based on closest embedding [batch_size]
    """
    def _to_long_tensor(value):
        # Existing tensors are moved/cast; anything else is wrapped in a new tensor.
        if isinstance(value, torch.Tensor):
            return value.to(device=device, dtype=torch.long)
        return torch.tensor(value, dtype=torch.long, device=device)

    model.eval()

    with torch.no_grad():
        position_batch = _to_long_tensor(positions)
        puzzle_batch = _to_long_tensor(puzzles)

        batch_size = puzzle_batch.shape[0]
        grid_width = puzzle_batch.shape[-1]

        # forward() takes a target argument; zeros are passed as placeholders.
        dummy_target = torch.zeros(batch_size, dtype=torch.long, device=device)
        output, _, _, _ = model.forward(dummy_target, position_batch, puzzle_batch, mask=True)

        # Positions are [x, y]; the flattened grid is indexed row-major (y * width + x).
        target_indices = position_batch[:, 1] * grid_width + position_batch[:, 0]
        sample_indices = torch.arange(batch_size, device=device)
        predicted_embedding = output[sample_indices, target_indices]  # [batch_size, total_dim]

        # Keep only the token part of the vector; the trailing components are not
        # comparable with the plain token embeddings returned by model.embed().
        predicted_token_part = predicted_embedding[:, :model.embedding_dim]  # [batch_size, embedding_dim]

        # Candidate answers are the digits 1-9 only (see docstring for why 0 is excluded).
        candidate_tokens = torch.arange(1, 10, dtype=torch.long, device=device)
        candidate_embeddings = model.embed(candidate_tokens)  # [9, embedding_dim]

        # Cosine similarity = dot product of L2-normalised vectors.
        predicted_norm = torch.nn.functional.normalize(predicted_token_part, p=2, dim=-1)
        candidate_norm = torch.nn.functional.normalize(candidate_embeddings, p=2, dim=-1)
        similarities = torch.matmul(predicted_norm, candidate_norm.T)  # [batch_size, 9]

        # argmax yields an index into candidate_tokens; map it back to the token id.
        best_candidate = torch.argmax(similarities, dim=-1)
        return candidate_tokens[best_candidate]

# Test model accuracy
# Generate a fresh batch of samples, predict the token at every target position
# in one call, then compare against the value stored in the grid at that position.
TEST_SIZE = 100
test_target, test_position, test_puzzles, test_original_puzzle = generator.generate_target_context_pairs(size=TEST_SIZE)
print(test_target)
all_predictions = predict_token(model, test_puzzles, test_position, device=device)

# Get actual values at the positions.
# Positions are stored as [x, y] while a grid is indexed [row, column] = [y, x].
all_actuals = [
    puzzle[position[1].item(), position[0].item()].item()
    for puzzle, position in zip(test_puzzles, test_position)
]
all_actuals = torch.tensor(all_actuals, device=device)

# Calculate accuracy (guard against an empty batch so the report never divides by zero)
correct = (all_predictions == all_actuals).sum().item()
total = len(all_predictions)
accuracy = correct / total if total else 0.0
print(f"Validation Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Correct: {correct}/{total}")

# Create confusion matrix manually
# Rows are actual tokens, columns are predicted tokens; both axes cover tokens 1-9.
cm = np.zeros((9, 9), dtype=int)
all_predictions_cpu = all_predictions.cpu().numpy()
all_actuals_cpu = all_actuals.cpu().numpy()

# Only tokens 1-9 have a cell in the matrix. A token of 0 would turn into index -1,
# which NumPy silently wraps to the last row/column (i.e. it would be counted as 9),
# so such samples are left out and reported instead.
in_range = (
    (all_actuals_cpu >= 1) & (all_actuals_cpu <= 9)
    & (all_predictions_cpu >= 1) & (all_predictions_cpu <= 9)
)
cm_rows = all_actuals_cpu[in_range].astype(int) - 1
cm_cols = all_predictions_cpu[in_range].astype(int) - 1
# np.add.at accumulates correctly when the same (row, col) pair occurs more than once.
np.add.at(cm, (cm_rows, cm_cols), 1)

skipped = int((~in_range).sum())
if skipped:
    print(f"Note: {skipped} sample(s) with a token outside 1-9 are not shown in the confusion matrix")

# Plot confusion matrix using matplotlib
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(cm, cmap='Blues', aspect='auto')

# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Count', rotation=270, labelpad=20)

# Set ticks and labels (tick k shows token k + 1)
ax.set_xticks(np.arange(9))
ax.set_yticks(np.arange(9))
ax.set_xticklabels(list(range(1, 10)))
ax.set_yticklabels(list(range(1, 10)))

# Add text annotations: dark text on light cells, white text on dark cells
for i in range(9):
    for j in range(9):
        ax.text(j, i, cm[i, j], ha="center", va="center",
                color="black" if cm[i, j] < cm.max() / 2 else "white")

# Labels and title
ax.set_xlabel('Predicted Token', fontsize=12)
ax.set_ylabel('Actual Token', fontsize=12)
ax.set_title(f'Confusion Matrix (Accuracy: {accuracy:.4f})', fontsize=14)

plt.tight_layout()
plt.show()
def visualize_attention(model, puzzle, target, position, device='cpu'):
    """
    Visualize the attention weights for each head in the model.

    The target cell is set to 0 in a copy of the grid. The grid is embedded
    with ``model.embed`` and ``model.pe``, flattened to an 81-long sequence
    and passed through ``model.mha``. For every head, the attention row that
    belongs to the target cell is drawn as a 9x9 heat map with the puzzle's
    values overlaid and the target cell marked with a star.

    Side effects: switches the model to eval mode, prints the per-head
    attention sums and shows a matplotlib figure.

    Args:
        model: The Sudoku2Vec model
        puzzle: A single sudoku grid [9, 9], indexable as ``puzzle[row, col]``
        target: The target token (scalar); only used in the figure title
        position: The position [2] as [x, y]
        device: The device to use for computation

    Raises:
        AssertionError: If a head's attention over the 81 cells does not sum to 1.
    """
    model.eval()

    # Plain ints serve tensor indexing, the flat-index arithmetic and the plot marker alike.
    x, y = int(position[0]), int(position[1])
    num_heads = model.num_heads

    with torch.no_grad():
        # Convert through one ndarray (a list of ndarrays is slow to convert and warns).
        # torch.tensor copies, so masking below never touches the caller's puzzle.
        sudoku_grid_masked = torch.tensor(
            np.asarray(puzzle), dtype=torch.long, device=device
        ).unsqueeze(0)  # add batch dimension

        # Mask the target in the grid
        sudoku_grid_masked[0, y, x] = 0

        # Get grid embeddings and reshape the grid to a sequence of 81 cells
        sudoku_grid_embeddings = model.embed(sudoku_grid_masked)
        sudoku_grid_with_position = model.pe(sudoku_grid_embeddings)
        grid_seq = sudoku_grid_with_position.view(1, 81, model.total_dim)

        # Get attention weights from the model
        # attention shape is [batch, num_heads, seq_len, seq_len]
        _, attention = model.mha(grid_seq, return_attention=True)

        # Select the attention row of the target cell (row-major flat index)
        position_idx = y * 9 + x
        attention_weights = attention.squeeze(0)[:, position_idx, :]  # [num_heads, 81]

        # Validate that attention sums to 1 for each head
        attention_sums = attention_weights.sum(dim=-1)
        print(f"Attention sums per head: {attention_sums.cpu().float().numpy()}")
        assert torch.allclose(attention_sums, torch.ones_like(attention_sums), atol=1e-5), \
            f"Attention weights do not sum to 1! Sums: {attention_sums}"

        attention_grids = attention_weights.view(num_heads, 9, 9).cpu().float().numpy()

    # Plot attention maps for each head on a two-row grid of subplots
    fig, axes = plt.subplots(2, (num_heads + 1) // 2, figsize=(15, 6))
    axes = axes.flatten()

    for head_idx, (ax, head_grid) in enumerate(zip(axes, attention_grids)):
        im = ax.imshow(head_grid, cmap='hot', interpolation='nearest')
        ax.set_title(f'Head {head_idx + 1}')
        ax.set_xlabel('Column')
        ax.set_ylabel('Row')

        # Add grid values as text; switch to white text on strongly attended cells
        for i in range(9):
            for j in range(9):
                text_color = 'white' if head_grid[i, j] > 0.5 else 'black'
                ax.text(j, i, f'{puzzle[i, j]}', ha='center', va='center',
                        color=text_color, fontsize=8, weight='bold')

        # Mark target position
        ax.plot(x, y, 'b*', markersize=15, markeredgecolor='cyan', markeredgewidth=2)

        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # With an odd number of heads the two-row grid has one spare cell; hide it
    # so it does not show up as an empty frame.
    for spare_ax in axes[num_heads:]:
        spare_ax.set_axis_off()

    plt.suptitle(f'Attention Maps for Target={target} at Position ({x}, {y})', fontsize=14)
    plt.tight_layout()
    plt.show()

# Example usage
# Pick a random validation sample. randrange(len(...)) keeps the index inside the
# validation set whatever its size (randint(0, 100) is inclusive of 100 and
# ignores how many samples actually exist).
example_idx = random.randrange(len(val_puzzles))
example_puzzle = val_puzzles[example_idx].cpu().numpy()
example_target = val_target[example_idx].item()
example_position = val_position[example_idx].cpu().numpy()

print("\n=== Visualizing Attention ===")
print(f"Puzzle:\n{example_puzzle}")
print(f"Target: {example_target} at position {example_position}")

# Get prediction: add a leading batch axis of size 1 to each input, then take the
# single result back out of the returned batch.
predicted_token = predict_token(
    model,
    example_puzzle[np.newaxis],
    example_position[np.newaxis],
    device=device,
)[0].item()
print(f"Predicted: {predicted_token}, Actual: {example_target}")

visualize_attention(model, example_puzzle, example_target, example_position, device=device)

# DEBUG MODE TEST: Run loss() with debug=True on one input
print("=" * 80, "TESTING DEBUG MODE", "=" * 80, sep="\n")

# Take a one-element slice (not a scalar index) of each validation tensor so the
# batch dimension is kept.
debug_sample_idx = 0
debug_window = slice(debug_sample_idx, debug_sample_idx + 1)
debug_target = val_target[debug_window]
debug_position = val_position[debug_window]
debug_puzzle = val_puzzles[debug_window]

debug_x = debug_position[0, 0].item()
debug_y = debug_position[0, 1].item()
print(
    "\nRunning loss() with debug=True on one validation sample...",
    f"Target: {debug_target.item()}",
    f"Position: ({debug_x}, {debug_y})",
    f"Puzzle shape: {debug_puzzle.shape}",
    sep="\n",
)

# Evaluate in eval mode and without gradient tracking
model.eval()
with torch.no_grad():
    debug_inputs = [
        batch.long().to(device)
        for batch in (debug_target, debug_position, debug_puzzle)
    ]
    # loss() returns four values: total loss, contrastive loss, cosine loss, accuracy
    loss, contrastive_loss, cosine_loss, accuracy = model.loss(
        *debug_inputs,
        negative_samples=2,
        debug=True,
    )

# Report the four returned values
accuracy_value = accuracy.item()
print("\nLoss Computation Complete:")
print(f"  Total Loss: {loss.item():.4f}")
print(f"  Contrastive Loss: {contrastive_loss.item():.4f}")
print(f"  Cosine Loss: {cosine_loss.item():.4f}")
print(f"  Accuracy: {accuracy_value:.4f} ({accuracy_value*100:.2f}%)")
print("=" * 80)

# Visualize embeddings after training
print("=" * 80, "VISUALIZING EMBEDDINGS AFTER TRAINING", "=" * 80, sep="\n")

# Embed every token id 0-9 and move the resulting table to NumPy for plotting
model.eval()
with torch.no_grad():
    all_tokens = torch.arange(0, 10, dtype=torch.long, device=device)
    all_embeddings = model.embed(all_tokens).float().cpu().numpy()  # [10, embedding_dim]

print(f"Embedding shape: {all_embeddings.shape}")
print(f"Embedding dimension: {model.embedding_dim}")

# Shared plotting inputs: one bar per embedding dimension, one colour per token
embedding_dim = all_embeddings.shape[1]
colors = plt.cm.tab10(np.linspace(0, 1, 10))
x = np.arange(embedding_dim)
dimension_labels = [str(i) for i in range(embedding_dim)]
stats_box = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

# One subplot per token, laid out as 2 rows x 5 columns
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
axes = axes.flatten()

for token_idx, (ax, token_embedding, bar_color) in enumerate(zip(axes, all_embeddings, colors)):
    # Bar chart of this token's embedding values
    ax.bar(x, token_embedding, color=bar_color, alpha=0.8)

    ax.set_xlabel('Dimension', fontsize=10)
    ax.set_ylabel('Value', fontsize=10)
    ax.set_title(f'Token {token_idx} Embedding', fontsize=12, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(dimension_labels, fontsize=8)
    ax.grid(True, alpha=0.3, axis='y')
    # Zero line, so positive and negative components are easy to tell apart
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    # Per-token mean and standard deviation, shown in the top-left corner
    mean_val = token_embedding.mean()
    std_val = token_embedding.std()
    ax.text(0.02, 0.98, f'μ={mean_val:.2f}, σ={std_val:.2f}',
            transform=ax.transAxes, fontsize=9, verticalalignment='top',
            bbox=stats_box)

plt.tight_layout()
plt.show()

# Print statistics over the whole embedding table
token_norms = np.linalg.norm(all_embeddings, axis=1)
print("\nEmbedding Statistics:")
print(f"  Mean value: {all_embeddings.mean():.4f}")
print(f"  Std value: {all_embeddings.std():.4f}")
print(f"  Min value: {all_embeddings.min():.4f}")
print(f"  Max value: {all_embeddings.max():.4f}")
print(f"  Mean norm per token: {token_norms.mean():.4f}")
print("=" * 80)


