In [3]:
# Import necessary packages
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time

## Saving a classifier

In this notebook, we will use the classifier that you built in notebook p1.

So first go back to that notebook and _export_ the classifier you built there by adding the following code to it:


In [None]:
# TODO: Maxime?

## Loading a pre-trained classifier

Now, we can load that pre-trained classifier in this notebook as follows:

In [None]:
# TODO: Maxime?

## Recap: solving a sudoku based on the predictions

In the following, we repeat the code from the previous notebook for sampling a visual sudoku and getting predictions.

We also include example _ortools_ code that solves the sudoku problem _(this requires installing ortools, e.g. `conda install ortools`)_.

In [4]:
# Sudoku puzzles, from http://hakank.org/minizinc/sudoku_problems2/index.html
# 0 marks an empty cell; every non-zero entry is a given digit.
sudoku_p0 = torch.IntTensor([[0,0,0, 2,0,5, 0,0,0],
                             [0,9,0, 0,0,0, 7,3,0],
                             [0,0,2, 0,0,9, 0,6,0],
                             [2,0,0, 0,0,0, 4,0,9],
                             [0,0,0, 0,7,0, 0,0,0],
                             [6,0,9, 0,0,0, 0,0,1],
                             [0,8,0, 4,0,0, 1,0,0],
                             [0,6,3, 0,0,0, 0,8,0],
                             [0,0,0, 6,0,8, 0,0,0]])

# sample a dataset index with that value/label
def sample_by_label(dataset, value):
    """Return a uniformly random index of `dataset` whose label equals `value`.

    Parameters
    ----------
    dataset : object with a `targets` attribute holding one label per sample
        (a tensor or a list, e.g. a torchvision dataset).
    value : the label to look for (int or 0-d tensor).

    Returns
    -------
    int
        An index `idx` such that `dataset.targets[idx] == value`.

    Raises
    ------
    ValueError
        If no sample in `dataset` carries that label (previously the function
        silently returned None, which crashed later in an unrelated place).
    """
    targets = torch.as_tensor(dataset.targets)
    # Vectorised: collect every matching index at once instead of scanning a
    # random permutation in a Python loop.
    candidates = torch.nonzero(targets == value).flatten()
    if len(candidates) == 0:
        raise ValueError(f"no sample with label {int(value)} in dataset")
    return int(candidates[torch.randint(len(candidates), (1,))])
# sample a dataset index for each non-zero number
def sample_visual_sudoku(sudoku_p, dataset):
    """Map every given digit of a sudoku to the index of a random image of that digit.

    Parameters
    ----------
    sudoku_p : integer tensor holding the puzzle; 0 marks an empty cell.
    dataset : dataset to draw the images from (passed on to `sample_by_label`).

    Returns
    -------
    torch.Tensor (dtype long), same shape as `sudoku_p`
        A dataset index for every non-zero cell and -1 for every empty cell.
    """
    nonzero = sudoku_p > 0
    vizsudoku = torch.full(sudoku_p.shape, -1, dtype=torch.long)
    # Bug fix: sample from the `dataset` argument, not from the global `trainset`.
    # int(...) accepts both plain ints and 0-d tensors from sample_by_label.
    sampled = [int(sample_by_label(dataset, value)) for value in sudoku_p[nonzero]]
    vizsudoku[nonzero] = torch.tensor(sampled, dtype=torch.long)
    return vizsudoku
# get predictions
def predict_sudoku(model, vizsudoku_idx, dataset):
    """Classify the image in every filled cell of a visual sudoku.

    Parameters
    ----------
    model : callable mapping a (k, n_pixels) float tensor to (k, n_classes)
        scores; presumably log-probabilities (the probability variant of this
        function exponentiates them) -- TODO confirm against the p1 classifier.
    vizsudoku_idx : long tensor of dataset indices, -1 for empty cells
        (as produced by `sample_visual_sudoku`).
    dataset : dataset whose `data` attribute holds the images.

    Returns
    -------
    torch.Tensor (int32) of shape (9, 9)
        The predicted class per filled cell (class index is used directly as
        the digit) and 0 for empty cells.
    """
    nonzero = vizsudoku_idx > -1
    predsudoku = torch.zeros(vizsudoku_idx.shape, dtype=torch.int32)
    if nonzero.any():  # an all-empty grid has no images to classify
        # Bug fix: read the images from the `dataset` argument, not the global `trainset`.
        # NOTE(review): images are taken raw from `dataset.data` (no dataset
        # transform applied) -- confirm this matches the p1 training preprocessing.
        images = dataset.data[vizsudoku_idx[nonzero]]
        # one row per image, e.g. (k, 28, 28) -> (k, 784); flatten also works for k == 0
        images = images.flatten(start_dim=1).float()
        with torch.no_grad():
            # argmax of the scores equals argmax of their exp(), so skip exp()
            # (it only risks overflowing large scores into ties)
            preds = torch.argmax(model(images), dim=1)
        predsudoku[nonzero] = preds.to(torch.int32)
    return predsudoku.reshape((9, 9))

# Sample a visual sudoku for sudoku_p0 and classify its images.
# NOTE: `model` and `trainset` are not defined in this notebook yet (the
# save/load cells above are still TODOs), which is why this cell raised NameError.
vizsudoku_idx = sample_visual_sudoku(sudoku_p0, trainset)
preds = predict_sudoku(model, vizsudoku_idx, trainset)

NameError: name 'trainset' is not defined

In [15]:
from ortools.sat.python import cp_model

# model and solve a sudoku with ortools
def model_sudoku_ort(grid):
    """Build an OR-tools CP-SAT model of a 9x9 sudoku.

    Parameters
    ----------
    grid : 9x9 nested list (or tensor) of ints; 0 marks an empty cell, any other
        value is a given clue that is fixed in the model.

    Returns
    -------
    (board, csp)
        `board` is a 9x9 list of lists of CP-SAT integer variables with domain
        1..9; `csp` is the `cp_model.CpModel` holding all the constraints.
    """
    csp = cp_model.CpModel()

    # one decision variable per cell, domain 1..9
    board = [[csp.NewIntVar(1, 9, 'x_%i%i' % (i, j)) for j in range(9)] for i in range(9)]

    # Fix the known cells. Bug fix: read the `grid` argument (not the global
    # `preds`) and index the nested lists as [i][j] (`board[i, j]` raises TypeError).
    for i in range(9):
        for j in range(9):
            clue = int(grid[i][j])
            if clue != 0:
                csp.Add(board[i][j] == clue)

    for i in range(9):
        csp.AddAllDifferent(board[i])  # all different rows
        csp.AddAllDifferent([board[k][i] for k in range(9)])  # all different columns

    # all different 3x3 blocks
    for si in range(3):
        for sj in range(3):
            csp.AddAllDifferent([board[3*si + i][3*sj + j] for i in range(3) for j in range(3)])

    return (board, csp)

def solve_sudoku_ort(grid):
    """Solve a sudoku with OR-tools CP-SAT.

    Parameters
    ----------
    grid : 9x9 nested list of ints, 0 = empty cell (see `model_sudoku_ort`).

    Returns
    -------
    list of list of int, or None
        The solved 9x9 grid, or None when the solver found no solution
        (infeasible puzzle, invalid model, or the solver gave up).
    """
    (board, csp) = model_sudoku_ort(grid)

    solver = cp_model.CpSolver()
    status = solver.Solve(csp)

    # Only OPTIMAL/FEASIBLE guarantee that solver.Value() is meaningful; the old
    # `status != INFEASIBLE` check also let MODEL_INVALID and UNKNOWN through.
    if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        return None
    return [[solver.Value(board[i][j]) for j in range(9)] for i in range(9)]

# Solve the sudoku using the predicted digits as clues (0 = empty cell).
sol = solve_sudoku_ort(preds.tolist())
sol

[[1, 2, 3, 4, 5, 6, 7, 8, 9],
 [4, 5, 6, 7, 8, 9, 1, 2, 3],
 [7, 8, 9, 1, 2, 3, 4, 5, 6],
 [2, 1, 4, 3, 6, 5, 8, 9, 7],
 [3, 6, 5, 8, 9, 7, 2, 1, 4],
 [8, 9, 7, 2, 1, 4, 3, 6, 5],
 [5, 3, 1, 6, 4, 2, 9, 7, 8],
 [6, 4, 2, 9, 7, 8, 5, 3, 1],
 [9, 7, 8, 5, 3, 1, 6, 4, 2]]

## Finding the maximum likelihood solution

Because misclassified digits may make the sudoku infeasible, we want to find the _maximum likelihood_ solution instead.

First, we compute and store the prediction probabilities instead of the predictions. We obtain a 9x9x9 tensor (last dimension = probabilities of digits 1..9).


In [None]:
# get probabilities of predictions
def predict_proba_sudoku(model, vizsudoku_idx, dataset):
    """Return, per cell, the predicted probabilities of the digits 1..9.

    Parameters
    ----------
    model : callable mapping a (k, n_pixels) float tensor to (k, 10)
        log-probabilities, one column per class 0..9 (class index == digit, as
        `predict_sudoku` assumes) -- TODO confirm against the p1 classifier.
    vizsudoku_idx : long tensor of dataset indices of shape (9, 9), -1 for
        empty cells (as produced by `sample_visual_sudoku`).
    dataset : dataset whose `data` attribute holds the images.

    Returns
    -------
    torch.Tensor of shape (9, 9, 9)
        Entry [i, j, c] is the predicted probability that cell (i, j) holds
        digit c + 1; empty cells are all zeros.
    """
    nonzero = vizsudoku_idx > -1
    probsudoku = torch.zeros((9, 9, 9))
    if nonzero.any():  # an all-empty grid has no images to classify
        # Bug fix: read the images from the `dataset` argument, not the global `trainset`.
        # NOTE(review): raw `dataset.data`, no dataset transform -- confirm this
        # matches the p1 training preprocessing.
        images = dataset.data[vizsudoku_idx[nonzero]].flatten(start_dim=1).float()
        with torch.no_grad():
            probs = model(images).exp()
        # Bug fix: write into `probsudoku` (the old code used an undefined
        # `predsudoku`) and keep only the columns for digits 1..9, dropping class 0.
        # No renormalisation is needed: it adds a per-cell constant to the log-likelihood.
        probsudoku[nonzero] = probs[:, 1:10]
    return probsudoku

# Inspect the (9, 9, 9) probability tensor of the sampled visual sudoku.
predict_proba_sudoku(model, vizsudoku_idx, trainset)

## Maximum likelihood estimation with standard CP solver

We need to turn the sudoku _satisfaction_ problem into an _optimisation_ problem that maximises the log-likelihood.

__Task: adapt the above code to find the maximum likelihood visual sudoku solution!__

This means adding an objective function: a weighted sum of indicators "decision variable $V_{i,j}$ takes value $c$", where each weight is the log-probability that the cell's image is digit $c$.

E.g. $\sum_i \sum_j \sum_c \log(\mathit{prob}[i,j,c]) \cdot [V_{i,j} = c]$, where $[\cdot]$ is 1 if the condition holds and 0 otherwise.

Note that only the objective is new, so you can reuse `model_sudoku_ort()` on an empty grid!

In [None]:
def solve_vizsudoku_ort(probs, scale=10000, eps=1e-6):
    """Find the maximum-likelihood solution of a visual sudoku with OR-tools CP-SAT.

    The sudoku constraints come from `model_sudoku_ort` on an empty grid. The
    objective maximises sum_{i,j,c} log(probs[i,j,c]) * [board[i][j] == c+1]
    over the cells that hold an image.

    Parameters
    ----------
    probs : tensor of shape (9, 9, 9); probs[i, j, c] is the probability that
        cell (i, j) is digit c + 1, and all-zero rows mark empty cells (as
        returned by `predict_proba_sudoku`).
    scale : CP-SAT needs integer objective coefficients, so the log-probabilities
        are multiplied by `scale` and rounded.
    eps : probabilities are clamped to at least `eps` before the log, so that
        log(0) = -inf never reaches the solver.

    Returns
    -------
    list of list of int, or None
        The solved 9x9 grid, or None if the solver found no solution.
    """
    # the constraint model, without clues: the objective decides the digits
    empty_grid = torch.zeros((9, 9), dtype=torch.int).tolist()
    # Bug fix: model_sudoku_ort returns a (board, csp) tuple
    (board, csp) = model_sudoku_ort(empty_grid)

    logprobs = torch.log(torch.clamp(probs, min=eps))
    has_image = probs.sum(dim=-1) > 0

    objective = []
    for i in range(9):
        for j in range(9):
            if not has_image[i, j]:
                # every digit would get the same weight log(eps): no effect on the optimum
                continue
            for c in range(9):
                digit = c + 1
                is_digit = csp.NewBoolVar('b_%i%i_%i' % (i, j, digit))
                # channel the Boolean to the cell value: is_digit <=> board[i][j] == digit
                csp.Add(board[i][j] == digit).OnlyEnforceIf(is_digit)
                csp.Add(board[i][j] != digit).OnlyEnforceIf(is_digit.Not())
                weight = int(round(scale * float(logprobs[i, j, c])))
                objective.append(weight * is_digit)
    if objective:
        csp.Maximize(sum(objective))

    solver = cp_model.CpSolver()
    status = solver.Solve(csp)

    # Only OPTIMAL/FEASIBLE guarantee meaningful solver.Value() results
    if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        return None
    return [[solver.Value(board[i][j]) for j in range(9)] for i in range(9)]

# Compute the per-cell digit probabilities and solve the visual sudoku for
# its maximum-likelihood assignment.
probs = predict_proba_sudoku(model, vizsudoku_idx, trainset)
psol = solve_vizsudoku_ort(probs)
psol

Now let's check how far this solution is from the true one...

**TODO (Maxime):** can you add visualizers? Thanks!
To get the full 'true' solution, you will probably need to solve the sudoku built from the true labels.