In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
from skimage import filters
import glob
import random
import pickle
import tkinter as tk
from tkinter import filedialog

In [None]:
# general config and set up
# trained model file name (pickled classifier read by loadModel)
fileModel = 'sudoku_clf.kb'
# folder where the 81 cropped grid squares are saved and read back.
# Relative, so it is the same './testRun/' folder saveImg writes to whatever
# the working directory is (an absolute D:\ path only matched that when the
# notebook was run from one specific checkout).
squaresDir = './testRun/'
# create it up front: collectX lists this folder and fails if it is missing,
# and cv2.imwrite does not create folders
os.makedirs(squaresDir, exist_ok=True)
# file name for the camera-acquired sudoku snapshot (written by getFrame)
myFile = "opencv_frame_0.png"
            

In [None]:
#This will display all the available mouse click events  
#events = [i for i in dir(cv2) if 'EVENT' in i]
#print(events)

# click event function
def click_event(event, x, y, flags, param):
    """OpenCV mouse callback for the "image" window.

    On a left-button press, appends [x, y] to the global list refPt, draws a
    small marker on the global image img and refreshes the window. All other
    mouse events are ignored; flags and param are unused.
    """
    if event != cv2.EVENT_LBUTTONDOWN:
        return
    refPt.append([x, y])
    # NOTE(review): the marker is drawn into img itself, so it is also present
    # in the image that is later warped and cut into squares
    cv2.circle(img, (x, y), 2, (255, 0, 0), 2)
    cv2.imshow("image", img)

def imgProcessing(img):
    """Turn one BGR grid-square crop into the feature image fed to the classifier.

    Steps: resize to 30x30 (area interpolation), convert to grayscale, CLAHE
    contrast equalisation, Gaussian blur (sigma=1), then a Sobel edge filter.
    Returns the filtered 2-D image.
    """
    img = cv2.resize(img, (30, 30), interpolation=cv2.INTER_AREA)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    img = clahe.apply(img)
    # img is 2-D (grayscale) here, so no channel argument is needed. The old
    # `multichannel=False` keyword was deprecated in scikit-image 0.19 and is
    # rejected with a TypeError by releases that removed it.
    img = filters.gaussian(img, sigma=1)
    img = filters.sobel(img)
    return img

def ensureSquare(listIs):
    """Shrink the longer side of a rectangle so that it becomes a square.

    listIs holds four [x, y] corners ordered top-left, top-right, bottom-left,
    bottom-right (as returned by getDst). The width is x(top-right) - x(top-left)
    and the height is y(bottom-left) - y(top-right).
    If the rectangle is wider than tall, the x of corners 1 and 3 moves left.
    If it is taller than wide, the y of corners 2 and 3 moves up.
    The list is modified in place and also returned.
    """
    width = listIs[1][0] - listIs[0][0]
    height = listIs[2][1] - listIs[1][1]
    excess = abs(width - height)
    if width > height:
        for corner in (1, 3):
            listIs[corner][0] -= excess
    elif height > width:
        for corner in (2, 3):
            listIs[corner][1] -= excess
    return listIs

def getDst(start):
    """Return the axis-aligned rectangle that encloses four corner points.

    start holds four [x, y] points ordered top-left, top-right, bottom-left,
    bottom-right. The result uses the same corner order, as a new list of
    [x, y] lists.
    """
    left = min(start[0][0], start[2][0])
    right = max(start[1][0], start[3][0])
    top = min(start[0][1], start[1][1])
    bottom = max(start[2][1], start[3][1])
    return [[left, top], [right, top], [left, bottom], [right, bottom]]

def formatForWrap(listIs):
    """Pack the first four [x, y] points of listIs into a 4x2 float32 array.

    This is the point format passed to cv2.getPerspectiveTransform in unwarp.
    """
    return np.float32([(listIs[k][0], listIs[k][1]) for k in range(4)])

def unwarp(img, src, dst):
    """Warp img with the perspective transform that maps the points src onto dst.

    src, dst: 4x2 float32 point arrays (as built by formatForWrap).
    The output has the same width and height as img; bilinear interpolation.
    """
    height, width = img.shape[:2]
    transform = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(img, transform, (width, height), flags=cv2.INTER_LINEAR)

def drawGrid(img, coords):
    """Draw the 9x9 cell grid onto img (in place) and show it with matplotlib.

    coords: four [x, y] corner points; coords[0] is the grid's top-left corner
    and coords[1][0] - coords[0][0] its width. Values are truncated to int.
    The caller's coords are not modified.
    """
    # Use the argument (this used to read the global dstList) and build fresh
    # int lists, so the caller's corner list is left untouched.
    pts = [[int(v) for v in pt[:2]] for pt in coords]
    x0, y0 = pts[0]
    cellWidth = int((pts[1][0] - x0) / 9)
    colourIs = (0, 0, 255)

    for i in range(0, 10):
        offset = i * cellWidth
        # horizontal line i
        cv2.line(img, (x0, y0 + offset),
                 (x0 + 9 * cellWidth, pts[1][1] + offset), colourIs, 2)
        # vertical line i
        cv2.line(img, (x0 + offset, y0),
                 (x0 + offset, pts[1][1] + 9 * cellWidth), colourIs, 2)
    # NOTE(review): img is presumably BGR (OpenCV), so matplotlib shows swapped colours
    plt.imshow(img)
    plt.show()

def saveImg(img, coords, file, outDir='./testRun/'):
    """Cut the warped sudoku into its 81 cells and write each cell as a PNG.

    img    : warped image, indexed img[row, col, ...].
    coords : corner points as [x, y]; coords[0] is the grid's top-left corner
             and coords[1][0] - coords[0][0] its width. Values are truncated
             to int; the caller's list is not modified.
    file   : name prefix; cell k (row-major, k = 0..80) is written to
             os.path.join(outDir, '<file>_<k>.png').
    outDir : destination folder, created if missing. The default is the folder
             this function always used; pass squaresDir if that differs.

    Raises OSError if OpenCV fails to write a cell image.
    """
    x0, y0 = int(coords[0][0]), int(coords[0][1])
    cellWidth = int((int(coords[1][0]) - x0) / 9)
    # cv2.imwrite does not create missing folders, so make sure this one exists
    os.makedirs(outDir, exist_ok=True)

    imCounter = 0
    for i in range(0, 9):  # grid row
        for ii in range(0, 9):  # grid column
            top = y0 + i * cellWidth
            left = x0 + ii * cellWidth
            # no explicit channel slice, so grayscale images work too
            im = img[top:top + cellWidth, left:left + cellWidth]
            path = os.path.join(outDir, '{}_{}.png'.format(file, imCounter))
            # imwrite reports failure by returning False; do not ignore it
            if not cv2.imwrite(path, im):
                raise OSError('Could not write {}'.format(path))
            imCounter = imCounter + 1

def rescaleImg(img, maxHeight=480, maxWidth=640):
    """Shrink img so it fits within maxHeight x maxWidth pixels, keeping the aspect ratio.

    Images that are already small enough keep their size, because the scale
    factor is capped at 1. The default limits (480 x 640) are sized for
    on-screen display while clicking the grid corners.
    """
    height, width = img.shape[:2]
    # The width limit must be compared with the width. The old code tested
    # the height against 640, so wide-but-short images were never shrunk.
    scale_fraction = min(1, maxHeight / height, maxWidth / width)
    dim = (int(width * scale_fraction), int(height * scale_fraction))
    return cv2.resize(img, dim, interpolation=cv2.INTER_AREA)

def display_grid(grid):
    """Print grid one row per line with each cell followed by a space.

    Empty cells (0 or '0') are printed as '.'. A blank separator follows the grid.
    """
    for row in grid:
        print(''.join(('.' if cell in (0, '0') else str(cell)) + ' ' for cell in row))
    print('\n')
    
def display_grid_corrections(grid):
    """Print grid with column indices on top and a row index in front of each row.

    The indices are the row/column numbers that correction input refers to.
    Empty cells (0 or '0') are printed as '.'.
    """
    print('  | ' + ''.join('{} '.format(col) for col in range(9)))
    print('-' * 22)
    for rowNum, row in enumerate(grid):
        cells = ''.join(('.' if cell in (0, '0') else str(cell)) + ' ' for cell in row)
        print('{} | {}'.format(rowNum, cells))
    print('\n')
    
def getNewVal():
    """Prompt for one correction and return its comma-separated fields as strings.

    All spaces are removed first, so ' 5, 1 ,2' gives ['5', '1', '2'] and an
    empty answer gives [''].
    """
    answer = input('New value and its position [number, row, column]: ')
    return answer.replace(" ", "").split(',')

def getImageLocation():
    """Let the user pick an existing image file with a Tk file dialog.

    Returns (dataDir, myFile): the folder part of the chosen path with a
    trailing '/', and the bare file name, so dataDir + myFile is the full path.
    Raises FileNotFoundError if the dialog is cancelled.
    """
    root = tk.Tk()
    root.withdraw()  # hide the empty main window; only the dialog is wanted
    try:
        file_path = filedialog.askopenfilename()
    finally:
        # release the hidden Tk root, otherwise every call leaks one
        root.destroy()

    if not file_path:
        # cancelling yields an empty result; fail here instead of later
        # inside cv2.imread / rescaleImg with an obscure error
        raise FileNotFoundError('No image selected.')

    # same split as before: everything up to the last '/' is the folder
    dataDir, _, myFile = file_path.rpartition('/')
    return dataDir + '/', myFile

def loadModel(fileModel):
    """Load the pickled classifier stored in fileModel.

    Returns the unpickled object, or None (after printing a message) if the
    file does not exist.
    """
    # NOTE: pickle.load can execute arbitrary code; only load model files you
    # created yourself or otherwise trust.
    try:
        # `with` closes the file even if unpickling raises
        with open(fileModel, 'rb') as file:
            return pickle.load(file)
    except FileNotFoundError:
        print('Trained model file "{}" not found'.format(fileModel))
        return None

# clear old data
def clearOldData(squaresDir):
    """Delete the files left in squaresDir by a previous run.

    squaresDir is used as a glob prefix, so it should end with a path
    separator. Sub-directories are skipped, because os.remove cannot delete
    them and would raise. A missing folder is fine: the glob matches nothing.
    """
    for path in glob.glob(squaresDir + '*'):
        if not os.path.isdir(path):
            os.remove(path)

def getFrame(myFile):
    """Show the webcam feed with a 9x9 guide grid and save snapshots to myFile.

    Keys: SPACE writes the current frame to myFile (each press overwrites the
    previous one); ESC closes the window.
    Returns the guide square's corners as a 4x2 float32 array of [x, y] points
    ordered top-left, top-right, bottom-left, bottom-right.
    Raises RuntimeError if no frame could be read from the camera.
    """
    padding = 80  # margin (px) between the guide square and the frame's shorter side
    colourIs = (0, 0, 255)
    coords = None  # stays None if the camera never delivers a frame

    cam = cv2.VideoCapture(0)
    try:
        cv2.namedWindow("test")
        while True:
            ret, frame = cam.read()
            if not ret:
                print("failed to grab frame")
                break
            maxR = frame.shape[0]
            maxC = frame.shape[1]
            cent = [maxR / 2, maxC / 2]
            lensS = (min(maxR, maxC) - padding * 2) / 2
            # corners as [row, col]: top-left, top-right, bottom-left, bottom-right
            coords = [[int(cent[0] - lensS), int(cent[1] - lensS)],
                      [int(cent[0] - lensS), int(cent[1] + lensS)],
                      [int(cent[0] + lensS), int(cent[1] - lensS)],
                      [int(cent[0] + lensS), int(cent[1] + lensS)]]
            cellWidth = int(lensS * 2 / 9)
            top, left = coords[0]

            # NOTE(review): the guide is drawn on `frame` itself, so the saved
            # snapshot also contains these red lines
            for i in range(0, 10):
                offset = i * cellWidth
                # horizontal line i (cv2 points are (x, y) = (col, row))
                cv2.line(frame, (left, top + offset),
                         (left + 9 * cellWidth, coords[1][0] + offset), colourIs, 2)
                # vertical line i
                cv2.line(frame, (left + offset, top),
                         (left + offset, coords[1][0] + 9 * cellWidth), colourIs, 2)

            cv2.imshow("test", frame)

            k = cv2.waitKey(1)
            if k % 256 == 27:
                # ESC pressed
                print("Escape hit, closing...")
                break
            elif k % 256 == 32:
                # SPACE pressed
                cv2.imwrite(myFile, frame)
                print("{} written!".format(myFile))
    finally:
        # always free the camera and close the window, even on error
        cam.release()
        cv2.destroyAllWindows()

    if coords is None:
        raise RuntimeError('No frame could be read from the camera.')
    # [row, col] -> [x, y]. This previously converted the global `src`
    # instead of these corners (NameError on a fresh kernel).
    return np.asarray([[col, row] for row, col in coords], dtype=np.float32)

def collectX(squaresDir):
    """Load the saved grid squares from squaresDir and flatten each into a feature vector.

    Files are ordered by the cell index saveImg appends ('<prefix>_<k>.png'),
    so X[k] is cell k in row-major order. getPredictions relies on this order
    when it reshapes the predictions to 9x9. os.listdir order is arbitrary, and
    a plain name sort would put '_10' before '_2'.
    Raises ValueError if a file cannot be read as an image.
    """
    def cellIndex(name):
        # numbered files first, by their trailing integer; anything else after, by name
        suffix = os.path.splitext(name)[0].rsplit('_', 1)[-1]
        return (0, int(suffix), name) if suffix.isdecimal() else (1, 0, name)

    X = []
    for name in sorted(os.listdir(squaresDir), key=cellIndex):
        img = cv2.imread(squaresDir + name)
        if img is None:
            raise ValueError('Could not read image "{}"'.format(squaresDir + name))
        img = imgProcessing(img)
        # NOTE(review): scaling kept as-is; presumably it matches the
        # preprocessing used when the model was trained
        img = img / 255
        X.append(np.ndarray.flatten(img))
    return X

def getPredictions(X, clf):
    """Classify the 81 square feature vectors and return a 9x9 grid of Python ints.

    X must be in row-major cell order (as produced by collectX). clf is any
    object with a predict(X) method returning 81 labels convertible by int().
    """
    labels = np.array(clf.predict(X)).reshape(9, 9)
    return [[int(label) for label in row] for row in labels.tolist()]

def applyCorrections(y_pred, image=None):
    """Interactively fix misread cells of y_pred (modified in place and returned).

    Each prompt (getNewVal) takes 'value, row, column'. Any answer that does
    not split into three comma-separated fields, e.g. an empty line, ends the
    session. Non-integer or out-of-range entries are rejected with a message
    instead of crashing or editing the wrong cell (negative indices used to
    wrap around). After every accepted correction the image and the updated
    grid are shown.

    image: picture to show; defaults to the notebook-global `warped`.
    """
    newVal = getNewVal()
    while len(newVal) == 3:
        try:
            value, row, col = (int(v) for v in newVal)
        except ValueError:
            print('Please enter three integers, e.g. 5, 0, 3')
        else:
            if 0 <= value <= 9 and 0 <= row <= 8 and 0 <= col <= 8:
                y_pred[row][col] = value
                plt.imshow(warped if image is None else image)
                plt.show()
                display_grid_corrections(y_pred)
            else:
                print('Value must be 0-9; row and column must be 0-8.')
        newVal = getNewVal()
    return y_pred

In [None]:
# solver
# lines and rows only have unique values
def lineCheck(line):
    """Return True if the positive entries of line contain no repeats.

    Zeros (empty cells) are ignored.
    """
    filled = [value for value in line if value > 0]
    return len(set(filled)) == len(filled)

# all numbers must be between 0 and 9
def valueCheck(line):
    """Return True when every value in line lies in 0..9 (inclusive).

    Raises ValueError for an empty line (min/max of an empty sequence).
    """
    values = list(line)
    lowest, highest = min(values), max(values)
    return lowest >= 0 and highest <= 9

# a sub-grid (3X3) only has unique values
def checkSquare(grid):
    """Return True if the non-zero values of a 3x3 numpy sub-grid are all different."""
    return lineCheck(np.ndarray.flatten(grid))

# wrapper to call all the check functions
def validateGrid(grid):
    """Check that grid is a well-formed (not necessarily solvable) sudoku.

    grid: 9x9 array-like of integers, 0 meaning an empty cell.
    Returns (True, None) if every check passes. Otherwise returns
    (False, errMsg) for the first failing check, in this order: shape, integer
    dtype, duplicates in a row/column, values outside 0..9, duplicates in a
    3x3 square.
    """
    grid = np.asarray(grid)
    if not grid.shape == (9, 9):
        errMsg = 'Invalid grid. Must be 9X9'
        return False, errMsg
    # Accept every integer dtype. NumPy's default int is int64 on Linux/macOS
    # (and on Windows since NumPy 2.0), which the old `== 'int32'` test rejected.
    if not np.issubdtype(grid.dtype, np.integer):
        errMsg = 'Invalid grid. All elements in the grid must be integers.'
        return False, errMsg
    cols = all(np.apply_along_axis(lineCheck, 0, grid))  # axis 0: one call per column
    rows = all(np.apply_along_axis(lineCheck, 1, grid))  # axis 1: one call per row
    if not rows or not cols:
        errMsg = 'Invalid grid. Duplicates in rows or columns.'
        return False, errMsg
    vals = all(np.apply_along_axis(valueCheck, 1, grid))
    if not vals:
        errMsg = 'Invalid grid. All numbers must be between 0 and 9.'
        return False, errMsg
    boxes = all(checkSquare(grid[r:r + 3, c:c + 3])
                for r in (0, 3, 6) for c in (0, 3, 6))
    if not boxes:
        errMsg = 'Invalid grid. Duplicates in a square.'
        return False, errMsg

    return True, None

def possible(y, x, n, grid):
    """
    Return True if digit n can be placed at row y, column x of grid.

    n is allowed only if it does not already appear in row y, in column x,
    or in the 3x3 box that contains (y, x); otherwise False is returned.
    """
    if any(grid[y][col] == n for col in range(9)):
        return False
    if any(grid[row][x] == n for row in range(9)):
        return False
    boxTop, boxLeft = 3 * (y // 3), 3 * (x // 3)
    inBox = any(grid[boxTop + dy][boxLeft + dx] == n
                for dy in range(3) for dx in range(3))
    return not inBox

class sudoku:
    """A 9x9 sudoku stored as a list of rows (0 = empty cell), with a backtracking solver.

    If the grid given to the constructor fails validateGrid, the error is
    printed and the instance gets no `grid` attribute. The other methods check
    for that and print a message instead.
    """

    def __init__(self, grid):
        isValidGrid, errMsg = validateGrid(np.array(grid))
        if isValidGrid:
            # Copy each row so that solving does not silently overwrite the
            # caller's grid (e.g. the notebook's y_pred).
            self.grid = [list(row) for row in grid]
        else:
            print(errMsg)

    # nice way to print the grid
    def print_grid(self):
        """Print the grid one row per line; empty cells (0) are shown as '.'."""
        if not hasattr(self, 'grid'):
            print('You need to provide a grid to print first.')
            return
        for line in self.grid:
            for square in line:
                if square == 0:
                    print('.', end=' ')
                else:
                    print(square, end=' ')
            print()
        print('\n')

    def solve_puzzle(self):
        """Fill every empty cell by depth-first backtracking.

        Returns self.grid, solved in place, when a solution exists.
        Otherwise prints 'This puzzle has no solutions.', leaves the grid as it
        was and returns None. Previously that message appeared only when the
        last cell happened to be the blank one.
        """
        if not hasattr(self, 'grid'):
            print('You need to provide a grid to solve first.')
            return
        if self._fill(0):
            return self.grid
        print('This puzzle has no solutions.')
        return None

    def _fill(self, start):
        """Complete the grid from flat cell index `start` (row-major) onward; True on success.

        Every cell before `start` is already non-zero, so the scan starts there.
        """
        for idx in range(start, 81):
            y, x = divmod(idx, 9)
            if self.grid[y][x] == 0:
                break
        else:
            return True  # no empty cell left: solved
        for n in range(1, 10):
            if possible(y, x, n, self.grid):
                self.grid[y][x] = n
                if self._fill(idx + 1):
                    return True
                # dead end: undo and try the next digit
                self.grid[y][x] = 0
        return False

In [None]:
# import the trained model obtained from notebook 04
# (loadModel prints a message and returns None if the file is missing)
clf = loadModel(fileModel)

In [None]:
# clear the squares left over from a previous run
clearOldData(squaresDir)

imageType = input('Load existing image [y/n]?: ')
if imageType.strip().lower() == 'y':
    dataDir, myFile = getImageLocation()
    refPt = []  # filled with [x, y] points by click_event
    img = cv2.imread(dataDir + myFile)
    if img is None:
        raise FileNotFoundError('Could not read image "{}"'.format(dataDir + myFile))
    img = rescaleImg(img)
    # getDst expects the corners in this order
    print('Click the four grid corners: top-left, top-right, bottom-left, '
          'bottom-right; then press any key.')
    cv2.imshow("image", img)
    # calling the mouse click event
    cv2.setMouseCallback("image", click_event)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    if len(refPt) != 4:
        raise ValueError('Expected 4 corner clicks, got {}.'.format(len(refPt)))
    src = formatForWrap(refPt)
else:
    src = getFrame(myFile)
    img = cv2.imread(myFile)
    if img is None:
        raise FileNotFoundError('No snapshot "{}" found - press SPACE in the '
                                'camera window to take one.'.format(myFile))

dstList = ensureSquare(getDst(src))
dst = formatForWrap(dstList)
warped = unwarp(img, src, dst)
saveImg(warped, dstList, myFile)

# OpenCV images are BGR; convert so matplotlib shows the true colours
plt.imshow(cv2.cvtColor(warped, cv2.COLOR_BGR2RGB))
plt.title(myFile)
plt.show()

In [None]:
# get the prediction: classify every saved square and show the recognised grid
X = collectX(squaresDir)
# getPredictions reshapes to 9x9, so exactly one sample per cell is required
if len(X) != 81:
    raise ValueError('Expected 81 squares in {}, found {}.'.format(squaresDir, len(X)))
if clf is None:
    raise RuntimeError('No trained model loaded; check "{}".'.format(fileModel))
y_pred = getPredictions(X, clf)
display_grid(y_pred)

In [None]:
# corrections: enter 'value, row, column' for each misread cell; an empty line finishes.
# Binding the result keeps the data flow explicit and avoids dumping the list as cell output.
y_pred = applyCorrections(y_pred)

In [None]:
# Solve the (corrected) puzzle and print the result. If the grid is invalid,
# sudoku() prints the validation error and the next two calls only print a notice.
obj = sudoku(y_pred)
obj.solve_puzzle()
obj.print_grid()