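# Solve Sudoku

Rectify a photo of a Sudoku puzzle into a top-down view of the grid, split it into 81 cells, and read the digit in each non-empty cell with a pre-trained PyTorch classifier.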

In [None]:
# Dataset locations, relative to the notebook's working directory.
TRAINING_IMAGES_FOLDER = "data/train/images/"
TRAINING_LABELS_PATH = "data/train/labels.csv"
# Trailing slash kept consistent with TRAINING_IMAGES_FOLDER.
TEST_IMAGES_FOLDER = "data/test/images/"

## Import libraries

In [None]:
import os
import tqdm

import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import mean_squared_error,mean_absolute_error
import matplotlib.pyplot as plt
%matplotlib inline

import cv2 as cv

In [None]:
def cv_imshow(img, ax=None):
    """Show an OpenCV-style image on a matplotlib axis.

    Parameters
    ----------
    img : numpy.ndarray
        Either a single-channel image of shape (H, W) or (H, W, 1), which is
        shown with a gray colormap, or a colour image in BGR channel order
        (as loaded by OpenCV), which is flipped to RGB for display.
    ax : matplotlib.axes.Axes, optional
        Axis to draw on. A new figure and axis are created when omitted.

    Returns
    -------
    matplotlib.axes.Axes
        The axis that was drawn on.
    """
    ndim = len(img.shape)
    # (H, W) and (H, W, 1) are both single-channel images.
    is_gray = ndim == 2 or (ndim == 3 and img.shape[-1] == 1)

    if ax is None:
        _, ax = plt.subplots()

    ax.axis("off")
    if is_gray:
        if ndim == 3:
            # matplotlib's imshow rejects (H, W, 1) arrays; drop the channel axis.
            img = img[:, :, 0]
        ax.imshow(img, cmap="gray")
    else:
        # OpenCV orders colour channels BGR; matplotlib expects RGB.
        ax.imshow(img[:, :, ::-1])
    return ax

In [None]:
# Grid alignment: this is the implementation in use.
def align_img(img, show=True):
    """Rectify a photographed Sudoku grid into a 450x450 top-down grayscale view.

    The image is padded with a 50-pixel black border, converted to grayscale,
    blurred and passed through Canny edge detection. The left-most,
    right-most, top-most and bottom-most points of the external contours are
    taken as the four grid corners and mapped onto a 450x450 square with a
    perspective transform.

    Parameters
    ----------
    img : numpy.ndarray
        BGR image as returned by ``cv.imread``.
    show : bool, default True
        If True, display the warped result with ``cv_imshow``.

    Returns
    -------
    numpy.ndarray
        The warped 450x450 grayscale image.

    Raises
    ------
    ValueError
        If ``img`` is None or no edges are found (grid corners cannot be located).
    """
    if img is None:
        raise ValueError("align_img: img is None (was the image loaded?)")

    # Pad with black (presumably so a grid touching the photo edge still
    # yields detectable edges).
    img = cv.copyMakeBorder(img, 50, 50, 50, 50, cv.BORDER_CONSTANT, value=(0, 0, 0))
    gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)

    # Blur to reduce noise before edge detection.
    blurred = cv.GaussianBlur(gray, (5, 5), 0)
    edges = cv.Canny(blurred, 100, 200)

    contours, _ = cv.findContours(edges, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    if not contours:
        raise ValueError("align_img: no edges found; cannot locate the grid corners")

    # All contour points as an (N, 2) array of (x, y), in contour order.
    # argmin/argmax return the first occurrence of the extreme value, which
    # is what a strict </> scan over the points would keep.
    points = np.vstack(contours).reshape(-1, 2)
    top_left = points[points[:, 0].argmin()]      # left-most point
    bottom_right = points[points[:, 0].argmax()]  # right-most point
    top_right = points[points[:, 1].argmin()]     # top-most point
    bottom_left = points[points[:, 1].argmax()]   # bottom-most point

    border = np.array([top_left, bottom_left, bottom_right, top_right])

    # NOTE(review): the angle range returned by cv.minAreaRect differs between
    # OpenCV versions (it changed around 4.5.1); confirm this 45-degree
    # threshold against the installed version.
    angle = cv.minAreaRect(border)[2]
    if angle <= 45:
        # Shift the corner order by one position (presumably to handle the
        # opposite rotation direction of the grid).
        border = np.array([top_right, top_left, bottom_left, bottom_right])

    # Map the corners (TL, BL, BR, TR) onto a 450x450 square: bird's-eye view.
    src = np.float32(border)
    dst = np.float32([[0, 0], [0, 450], [450, 450], [450, 0]])
    M = cv.getPerspectiveTransform(src, dst)
    warped = cv.warpPerspective(gray, M, (450, 450))

    if show:
        cv_imshow(warped)

    return warped

# Load one sample training image and rectify its grid.
filepath = os.path.join(TRAINING_IMAGES_FOLDER, "0000.png")
img = cv.imread(filepath)
# cv.imread returns None instead of raising when the file is missing or unreadable.
if img is None:
    raise FileNotFoundError(f"Could not read image: {filepath}")
warped = align_img(img)

In [None]:
import torch
from classifier.DigitClassifier import DigitClassifier
# Load the pre-trained digit classifier.

# Instantiate the model architecture.
model = DigitClassifier()

# os.path.join instead of a hard-coded backslash keeps the path valid on
# Linux/macOS; map_location="cpu" lets weights saved from a GPU session load
# on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
WEIGHTS_PATH = os.path.join("classifier", "digit_classifier.pth")
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))

# Switch to evaluation mode (changes the behaviour of layers such as
# dropout/batch-norm, if the model has any).
model.eval();

In [None]:
# Cell-emptiness check: this is the implementation in use.
def detect_digit(square, threshold=150, center=10):
    """Return True if the central window of a cell image contains any ink.

    Parameters
    ----------
    square : torch.Tensor or numpy.ndarray
        Cell image, expected to hold values in [0, 1] with ink bright, e.g. the
        (1, 28, 28) tensors built by ``extract_digits``. Singleton dimensions
        are squeezed away; the result must be 2-D.
    threshold : int, default 150
        After scaling to 0..255 and casting to uint8, pixels strictly above
        this value count as ink.
    center : int, default 10
        Side length of the central window that is inspected. Ink outside it
        is ignored. For a 28x28 cell, the window covers rows and columns 9..18.

    Returns
    -------
    bool
        True if at least one pixel inside the central window is above threshold.
    """
    pixels = square.numpy() if hasattr(square, "numpy") else np.asarray(square)
    pixels = (pixels * 255).astype(np.uint8).squeeze()

    # Same rule as cv.threshold(..., THRESH_BINARY): ink iff value > threshold.
    ink = pixels > threshold

    # Inspect only the middle of the cell (presumably so grid-line remnants
    # along the cell borders are not mistaken for a digit).
    height, width = ink.shape
    top = (height - center) // 2
    left = (width - center) // 2
    window = ink[top:top + center, left:left + center]

    return bool(window.any())

In [None]:
import torchvision.transforms as transforms
from PIL import Image

def extract_digits(img, verbose=False):
    """Split a rectified Sudoku grid into 81 cells and classify each one.

    The image is divided into a 9x9 grid of equal cells (50x50 pixels for the
    450x450 output of ``align_img``). Each cell is resized to a 28x28
    grayscale tensor and inverted (1 - x) so dark pixels become high values.
    Values below 0.4 are then zeroed. Cells in which ``detect_digit`` finds
    ink are classified with the global ``model``; the others are recorded
    as 0.

    Parameters
    ----------
    img : numpy.ndarray
        Rectified uint8 grayscale grid, e.g. the result of ``align_img``.
    verbose : bool, default False
        If True, plot every detected cell and print the predicted digit and
        class probabilities (debug output).

    Returns
    -------
    digits : list[int]
        81 entries in row-major order. 0 is written for cells where no digit
        was detected.
    img_digits : list[torch.Tensor]
        The preprocessed (1, 28, 28) tensor for every cell, in the same order.
    """
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    # Derive the cell size from the image; a 450x450 input gives 50x50 cells.
    cell_h, cell_w = img.shape[0] // 9, img.shape[1] // 9

    img_digits = []
    digits = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for row in range(9):
            for col in range(9):
                cell = img[row * cell_h:(row + 1) * cell_h,
                           col * cell_w:(col + 1) * cell_w]
                # Invert so dark pixels become high values, then drop faint ones.
                cell_tensor = 1 - transform(Image.fromarray(cell))
                cell_tensor[cell_tensor < 0.4] = 0
                img_digits.append(cell_tensor)

                if not detect_digit(cell_tensor):
                    digits.append(0)
                    if verbose:
                        print('digit doesnt exist')
                    continue

                # NOTE(review): the tensor is passed unbatched with shape
                # (1, 28, 28); confirm DigitClassifier handles this (argmax
                # below assumes a 2-D output indexed on dim=1).
                output = model(cell_tensor)
                digit = torch.argmax(output, dim=1).item()
                digits.append(digit)

                if verbose:
                    plt.imshow(cell_tensor.squeeze(), cmap='gray')
                    plt.show()
                    print(f'digit: {digit}, confidence: {torch.softmax(output, dim=1)}')

    return digits, img_digits
digits, img_digits = extract_digits(warped)

# extract_digits emits one entry per cell, so a full grid has 81 values.
assert len(digits) == 81, f"expected 81 cells, got {len(digits)}"

# Show the recognised grid row by row; 0 is written where no digit was detected.
print(np.array(digits).reshape(9, 9))