## Summary

In this notebook, we load a network trained to solve Sudoku puzzles and use it to solve a single Sudoku puzzle, either entered through an interactive grid or chosen from a few example puzzles.

----

## Imports

In [None]:
import functools
import io
import os
import sys
import tempfile
import time
from collections import deque
from pathlib import Path

import ipywidgets as widgets
import numpy as np
import pandas as pd
import tqdm
from IPython.display import HTML, display
from ipywidgets import fixed, interact, interact_manual, interactive

import matplotlib as mpl
import matplotlib.pyplot as plt
import pyarrow
import torch
import torch.nn as nn
from matplotlib import cm
from torch_geometric.data import DataLoader

In [None]:
import proteinsolver
import proteinsolver.datasets
from proteinsolver.utils import gen_sudoku_graph_featured

In [None]:
# Non-interactive Agg backend: figures are presumably not shown automatically after each
# cell; the solver callback renders them explicitly with `display(fig)`.
%matplotlib agg

# Snapshot matplotlib's rcParams only the first time this cell runs. On a re-run
# `inline_rc` already exists, so the customized values set below do not overwrite
# the saved defaults.
try:
    inline_rc
except NameError:
    inline_rc = mpl.rcParams.copy()
    
# Default font size for all figures drawn in this notebook.
mpl.rcParams.update({"font.size": 12})

## Parameters

In [None]:
# Training-run identifier: selects the sudoku_train/<UNIQUE_ID>/ directory holding the
# model definition and the checkpoint files loaded below.
UNIQUE_ID = 'c8de7e56'

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Show which device the network will be moved to.
device

## Load model

In [None]:
# Execute the model definition saved with this training run, making its top-level
# definitions available in this notebook.
# NOTE(review): this presumably defines the `Net` class used below, which is not defined
# anywhere else in this notebook; confirm against sudoku_train/<UNIQUE_ID>/model.py.
%run sudoku_train/{UNIQUE_ID}/model.py

In [None]:
def _state_file_sort_key(path):
    """Sort key for checkpoint files, derived from the dash-separated parts of the file stem.

    Primary key: the 4th part as an int, after stripping any leading/trailing 'a', 'm', 'v'
    characters. Secondary key: the 3rd part as an int, after stripping 'd' characters.
    NOTE(review): what these fields measure is not visible here; confirm against the
    training code that writes the ``*.state`` files.
    """
    parts = path.stem.split("-")
    primary = int(parts[3].strip("amv"))
    secondary = int(parts[2].strip("d"))
    return primary, secondary


# Ascending order, so the last entry has the largest (primary, secondary) key.
state_files = sorted(Path("sudoku_train", UNIQUE_ID).glob("*.state"), key=_state_file_sort_key)

In [None]:
# Use the checkpoint that sorts last, i.e. the one with the largest sort key.
# Fail with an explicit message instead of an opaque IndexError when the run directory
# contains no checkpoints (wrong UNIQUE_ID, or the notebook started from another folder).
if not state_files:
    raise FileNotFoundError(f"No '*.state' checkpoint files found in sudoku_train/{UNIQUE_ID}")
state_file = state_files[-1]

In [None]:
# The constructor arguments presumably have to match the run that produced the checkpoint:
# load_state_dict (strict by default) rejects missing/unexpected keys and shape mismatches.
net = Net(
    x_input_size=13, adj_input_size=3, hidden_size=162, output_size=9, batch_size=8
).to(device)

# map_location loads the saved tensors directly onto the selected device.
net.load_state_dict(torch.load(state_file, map_location=device))
# Switch to inference mode (disables dropout / batch-norm statistic updates, if the model
# has any). The module is already on `device`, so a second .to(device) is unnecessary.
net = net.eval()

## Define widgets

### Sudoku grid

In [None]:
def _build_sudoku_grid(cell_width="42px", box_padding="5px"):
    """Build the 9x9 Sudoku input grid out of ipywidgets.

    Layout: a VBox of three HBox bands. Each band holds three padded VBox 3x3 boxes,
    and each box holds three HBox rows of three ``BoundedIntText`` cells that accept
    0-9 (0 = empty cell).

    Returns
    -------
    grid : widgets.VBox
        The assembled grid, ready to ``display``.
    lookup : list of list of widgets.BoundedIntText
        ``lookup[i][j]`` is the input cell for Sudoku row ``i``, column ``j``.
    """
    lookup = [[None] * 9 for _ in range(9)]
    bands = []
    for band in range(3):  # which row of 3x3 boxes
        boxes = []
        for stack in range(3):  # which column of 3x3 boxes
            box_rows = []
            for sub_row in range(3):
                cells = []
                for sub_col in range(3):
                    # NOTE: `allow_none` is intentionally not passed. It is a trait-definition
                    # option, not a BoundedIntText keyword, so traitlets would only emit a
                    # deprecation warning for it.
                    cell = widgets.BoundedIntText(
                        value=0,
                        min=0,
                        max=9,
                        step=1,
                        description="",
                        layout={"width": cell_width},
                    )
                    lookup[band * 3 + sub_row][stack * 3 + sub_col] = cell
                    cells.append(cell)
                box_rows.append(widgets.HBox(cells))
            boxes.append(widgets.VBox(box_rows, layout={"padding": box_padding}))
        bands.append(widgets.HBox(boxes))
    return widgets.VBox(bands), lookup


# Built inside a helper so the loop temporaries do not leak into the notebook namespace.
sudoku_widget, sudoku_widget_lookup = _build_sudoku_grid()

### Puzzle selector

In [None]:
# The empty puzzle: every one of the 81 cells is 0 (= empty).
puzzle_0 = torch.zeros((9, 9), dtype=torch.long)

In [None]:
# Example puzzle 1, one string per Sudoku row (top to bottom); "0" marks an empty cell.
puzzle_1 = torch.tensor(
    [
        [int(digit) for digit in row]
        for row in (
            "080032001",
            "703080002",
            "500007030",
            "050001970",
            "600709008",
            "047200050",
            "020600009",
            "800090305",
            "300820010",
        )
    ]
)

In [None]:
# Example puzzle 2, one string per Sudoku row (top to bottom); "0" marks an empty cell.
# Defined inline rather than parsed from tab-separated text. An all-blank row made of bare
# tabs is easily stripped by editors, and pandas then skips it, silently producing an 8x9
# grid. torch.tensor also fixes the dtype to int64 on every platform, whereas
# `astype(int)` is int32 on Windows with NumPy < 2.
puzzle_2 = torch.tensor(
    [
        [int(digit) for digit in row]
        for row in (
            "006348910",
            "000060408",
            "000000000",
            "000600709",
            "060891245",
            "090270103",
            "001450090",
            "002906030",
            "905000600",
        )
    ]
)

In [None]:
# Example puzzle 3, one string per Sudoku row (top to bottom); "0" marks an empty cell.
# Defined inline rather than parsed from CSV text. This avoids reusing the notebook-global
# `buf`/`df`, and torch.tensor fixes the dtype to int64 on every platform, whereas
# `astype(int)` is int32 on Windows with NumPy < 2.
puzzle_3 = torch.tensor(
    [
        [int(digit) for digit in row]
        for row in (
            "000070020",
            "000040700",
            "009036140",
            "103450890",
            "407600200",
            "008010000",
            "302560008",
            "800003064",
            "960004010",
        )
    ]
)

In [None]:
def empty_out_puzzle(b, puzzle_matrix):
    """Load ``puzzle_matrix`` into the grid's input widgets.

    Despite the name, this copies any 9x9 grid (not only an empty one) into the cells,
    row by row. ``b`` receives the clicked button when the function is used as an
    ``on_click`` callback; it is ignored.
    """
    for flat_index in range(81):
        row, col = divmod(flat_index, 9)
        sudoku_widget_lookup[row][col].value = puzzle_matrix[row][col]

In [None]:
def _make_puzzle_button(description, tooltip, puzzle_matrix, icon=""):
    """Return a button that loads ``puzzle_matrix`` into the Sudoku grid when clicked."""
    button = widgets.Button(
        description=description,
        disabled=False,
        button_style="",  # 'success', 'info', 'warning', 'danger' or ''
        tooltip=tooltip,
        icon=icon,
    )
    button.on_click(functools.partial(empty_out_puzzle, puzzle_matrix=puzzle_matrix))
    return button


empty_puzzle_button = _make_puzzle_button(
    "Empty", "Click me to set Sudoku grid to empty.", puzzle_0
)
test_puzzle_1_button = _make_puzzle_button(
    "Puzzle 1", "Click me to set Sudoku grid to puzzle 1", puzzle_1, icon="puzzle-piece"
)
test_puzzle_2_button = _make_puzzle_button(
    "Puzzle 2", "Click me to set Sudoku grid to puzzle 2", puzzle_2, icon="puzzle-piece"
)
test_puzzle_3_button = _make_puzzle_button(
    "Puzzle 3", "Click me to set Sudoku grid to puzzle 3", puzzle_3, icon="puzzle-piece"
)

# Row of buttons shown above the grid.
puzzle_selector_widget = widgets.HBox(
    [empty_puzzle_button, test_puzzle_1_button, test_puzzle_2_button, test_puzzle_3_button]
)

In [None]:
# Start the dashboard with example puzzle 1 already filled into the grid.
empty_out_puzzle(b=None, puzzle_matrix=puzzle_1)

### Puzzle solver

In [None]:
def encode_puzzle(puzzle):
    """Map Sudoku cell values to the network's input class indices.

    Digits 1-9 become 0-8; empty cells (0, or any other value below 1) become 9.
    The result keeps the input's dtype and device.
    """
    shifted = puzzle - 1
    return shifted.masked_fill(shifted < 0, 9)


def decode_puzzle(puzzle):
    """Inverse of ``encode_puzzle``: indices 0-8 become digits 1-9, index 9 becomes 0 (empty)."""
    return (puzzle + 1) % 10


# Sanity check over every value a grid cell can hold (0 = empty, 1-9), including the
# empty-cell branch. A private name avoids clobbering the notebook-level `puzzle`.
_cell_values = torch.arange(10)
assert torch.equal(encode_puzzle(_cell_values), torch.tensor([9, 0, 1, 2, 3, 4, 5, 6, 7, 8]))
assert torch.equal(decode_puzzle(encode_puzzle(_cell_values)), _cell_values)
del _cell_values

In [None]:
def solve_sudoku(net, puzzle):
    """Predict a completed grid for ``puzzle`` in a single forward pass of ``net``.

    Parameters
    ----------
    net : torch.nn.Module
        Trained Sudoku network, called as ``net(node_classes, edge_index, edge_attr)``.
    puzzle : torch.Tensor
        9x9 integer grid holding digits 1-9, with 0 for empty cells.

    Returns
    -------
    torch.Tensor
        9x9 CPU tensor of predicted cell values, decoded with ``decode_puzzle``.
    """
    # Sudoku constraint graph, converted to a sparse COO tensor to extract edges and features.
    sudoku_graph = torch.from_numpy(gen_sudoku_graph_featured()).to_sparse(2)
    edge_index = sudoku_graph.indices()
    edge_attr = sudoku_graph.values()

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        scores = net(
            encode_puzzle(puzzle).view(-1).to(device),
            edge_index.clone().to(device),
            edge_attr.clone().to(device),
        ).to("cpu")

    # Softmax is monotonic, so taking the argmax of the raw scores selects the same class.
    predicted = scores.argmax(dim=1)
    return decode_puzzle(predicted).reshape(9, 9)

In [None]:
def show_sudoku(puzzle, solved=None, pred=None, title="", color="black", ax=None):
    """Draw a Sudoku grid and its digits on a matplotlib Axes.

    Parameters
    ----------
    puzzle : 9x9 grid of ints
        Given digits; 0 marks an empty cell. Non-zero cells are drawn in ``color``.
    solved, pred : 9x9 grids of ints, optional
        Used only when both are given, and only for the empty cells of ``puzzle``.
        Where they agree, the digit is drawn in "C0". Otherwise ``pred`` is drawn in
        "C3" on the left and ``solved`` in "C2" on the right, in a smaller font.
    title : str
        Axes title.
    color : str
        Colour of the given digits.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 4.8 x 4.8 inch figure is created when omitted.

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4.8, 4.8))

    # Thin lines on every cell border, thick lines on the 3x3 box borders.
    # All horizontal lines are drawn first, then all vertical lines.
    line_specs = [(pos, 1) for pos in range(10)] + [(pos, 3) for pos in range(0, 10, 3)]
    for pos, width in line_specs:
        ax.plot([-0.05, 9.05], [pos, pos], color="black", linewidth=width)
    for pos, width in line_specs:
        ax.plot([pos, pos], [-0.05, 9.05], color="black", linewidth=width)

    ax.axis("image")
    ax.axis("off")  # only the grid matters; hide ticks and spines
    ax.set_title(title, fontsize=20)

    compare = solved is not None and pred is not None
    for x in range(9):
        for y in range(9):
            # Grid row 0 belongs at the top of the plot, i.e. at the largest y coordinate.
            row = 8 - y
            given = puzzle[row][x]
            if given > 0:
                ax.text(x + 0.25, y + 0.22, f"{given}", fontsize=20, color=color)
            elif compare:
                expected = solved[row][x]
                predicted = pred[row][x]
                if expected == predicted:
                    ax.text(x + 0.25, y + 0.22, f"{expected}", fontsize=20, color="C0")
                else:
                    ax.text(x + 0.1, y + 0.3, f"{predicted}", fontsize=13, color="C3")
                    ax.text(x + 0.55, y + 0.3, f"{expected}", fontsize=13, color="C2")

    return ax

In [None]:
def plot_no_conflicts(title="", ax=None):
    """Fill an Axes with a horizontally centred, "C2"-coloured "No conflicts!" message.

    Creates a new 4.8 x 4.8 inch figure when ``ax`` is omitted; returns the Axes.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4.8, 4.8))
    ax.axis("image")
    ax.axis("off")  # a text-only panel: no ticks or spines
    # Axes coordinates, so (0.5, 0.5) is the middle of the panel whatever its data limits.
    ax.text(
        0.5,
        0.5,
        "No conflicts!",
        fontsize=20,
        horizontalalignment="center",
        color="C2",
        transform=ax.transAxes,
    )
    ax.set_title(title, fontsize=20)
    return ax

In [None]:
# Preview the panel shown when a prediction breaks no Sudoku rule.
plot_no_conflicts(ax=None, title="Conflict")

In [None]:
# Preview puzzle 1 as an input grid (given digits only).
show_sudoku(puzzle=puzzle_1, title="Input")

In [None]:
# Preview the "solution" rendering path (trivially, puzzle 1 against itself).
show_sudoku(puzzle_1, solved=puzzle_1, pred=puzzle_1, title="Solution")

In [None]:
def find_conflict(puzzle):
    """Locate the first Sudoku rule violation in a 9x9 grid.

    Units are scanned in this order: rows 0-8, columns 0-8, then the 3x3 boxes in
    row-major order. Within each unit the digits 1-9 are checked in ascending order.
    Zeros (empty cells) never count as a conflict.

    Returns
    -------
    torch.Tensor or None
        A tensor shaped like ``puzzle``, zero everywhere except the cells holding the
        first repeated digit found, or ``None`` if no unit repeats a digit.
    """
    units = (
        [(r, slice(None)) for r in range(9)]
        + [(slice(None), c) for c in range(9)]
        + [(slice(r, r + 3), slice(c, c + 3)) for r in range(0, 9, 3) for c in range(0, 9, 3)]
    )
    for unit in units:
        cells = puzzle[unit]
        for digit in range(1, 10):
            hits = cells == digit
            if hits.sum() > 1:
                marked = torch.zeros_like(puzzle)
                # marked[unit] is a view, so the masked assignment writes into `marked`.
                marked[unit][hits] = cells[hits]
                return marked
    return None

In [None]:
# Column conflict: 7 already appears in column 0 (row 1), while row 0 has no other 7.
puzzle = puzzle_1.clone()
puzzle[0][0] = 7
find_conflict(puzzle)

In [None]:
# Row conflict: 8 already appears in row 0 (column 1).
puzzle = puzzle_1.clone()
puzzle[0][2] = 8
find_conflict(puzzle)

In [None]:
# Box conflict: neither row 6 nor column 2 contains an 8, but the bottom-left
# box already has one (row 7, column 0).
puzzle = puzzle_1.clone()
puzzle[6][2] = 8
find_conflict(puzzle)

In [None]:
def plot_solution(puzzle, solution):
    """Draw the solved grid next to a conflict report and return the new figure.

    Left panel: ``puzzle``'s given digits, with ``solution``'s values filling the empty
    cells (``solution`` is passed as both the expected and the predicted grid, so no cell
    is marked as a mismatch). Right panel: the first conflicting cells reported by
    ``find_conflict``, or a "No conflicts!" message. The caller owns the figure.
    """
    fig, (solution_ax, conflict_ax) = plt.subplots(1, 2, figsize=(9.8, 5))
    show_sudoku(puzzle, solution, solution, title="Solution", ax=solution_ax)
    conflict = find_conflict(solution)
    if conflict is None:
        plot_no_conflicts(title="Conflicts", ax=conflict_ax)
    else:
        show_sudoku(conflict, title="Conflict", color="C3", ax=conflict_ax)
    return fig

In [None]:
# Smoke test: render an empty puzzle end to end, then release the figure. Under the
# non-interactive backend it is never shown, and pyplot would otherwise keep it alive.
# Avoids rebinding IPython's `_` output-history variable.
plt.close(plot_solution(puzzle_0, puzzle_0))

In [None]:
solution_output_widget = widgets.Output(layout={'border': '1px solid black', "width": "600px"})


def solve_sudoku_from_widget(b):
    """Solve the puzzle currently entered in the grid and draw the result (Solve-button callback).

    ``b`` is the clicked button, supplied by ``Button.on_click``; it is unused.
    """
    with solution_output_widget:
        solution_output_widget.clear_output()
        # All work happens inside the Output context, so an error raised while reading
        # the grid or solving is captured and shown in the output area.
        puzzle = torch.zeros(9, 9, dtype=torch.int64)
        for i in range(9):
            for j in range(9):
                puzzle[i][j] = sudoku_widget_lookup[i][j].value
        solution = solve_sudoku(net, puzzle)
        fg = plot_solution(puzzle, solution)
        try:
            display(fg)
        finally:
            # display() has already rendered the figure. Close it so pyplot does not keep
            # one figure alive per click (memory growth, "too many open figures" warnings).
            plt.close(fg)

In [None]:
# Button that runs the network on the puzzle currently entered in the grid.
solve_button_widget = widgets.Button(
    tooltip='Click to solve the Sudoku puzzle',
    description='Solve!',
    icon='check',
    button_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
    disabled=False,
)
solve_button_widget.on_click(solve_sudoku_from_widget)

In [None]:
# Completed reference grid for puzzle 1, one string per Sudoku row (top to bottom).
solution = torch.tensor(
    [
        [int(digit) for digit in row]
        for row in (
            "489532761",
            "713486592",
            "562917834",
            "258341976",
            "631759248",
            "947268153",
            "125673489",
            "876194325",
            "394825617",
        )
    ]
)

In [None]:
# The reference grid must be a valid completed Sudoku...
assert proteinsolver.utils.sudoku.sudoku_is_solved(solution.to("cpu"))
# ...and must agree with every given digit of puzzle 1.
assert torch.equal(solution[puzzle_1 > 0], puzzle_1[puzzle_1 > 0]), "solution contradicts puzzle_1"

## Dashboard

### Solve a custom Sudoku puzzle

In [None]:
# Buttons that load an empty grid or one of the example puzzles.
display(puzzle_selector_widget)

In [None]:
# The editable 9x9 grid (0 = empty cell).
display(sudoku_widget)

In [None]:
# Runs the network on the grid's current contents.
display(solve_button_widget)

In [None]:
# Area where the solved grid and the conflict check are drawn after clicking "Solve!".
display(solution_output_widget)

In [None]:
# Footer with a support contact. The link needs the `mailto:` scheme; a bare address in
# `href` is resolved as a relative URL and does not open an email client.
display(HTML("""\
<hr>
<p>Running into issues? Please send an email to <a href="mailto:help@proteinsolver.org">help@proteinsolver.org</a>.
<br>
<em>This website works best using the latest versions of Firefox or Chrome web browsers.</em>
</p>
"""))