# Mount Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!ls "/content/drive/MyDrive/Digital Imaging/final-project/"

2_move_frames	       4_move_frames  8_move_frames  trials-notebooks
2_move_rotated_frames  6_move_frames  runs


In [3]:
# folder of frames from videoes
folder_path = "/content/drive/MyDrive/Digital Imaging/final-project/4_move_frames"

In [4]:
import shutil

def save_folder_to_drive(colab_folder_path, drive_folder_path, mount_drive = False):
    """
    Saves a folder from the Colab runtime to a directory in Google Drive.

    Args:
        colab_folder_path (str): Path to the folder in the Colab runtime.
        drive_folder_path (str): Path to the destination folder in Google Drive.

    Returns:
        str: Message indicating success or error.
    """
    try:
        if (mount_drive):
            drive.mount('/content/drive')

        # Ensure the Google Drive directory exists
        full_drive_path = os.path.join('/content/drive/My Drive', drive_folder_path)
        if not os.path.exists(full_drive_path):
            os.makedirs(full_drive_path)

        # Copy the folder from Colab runtime to Google Drive
        shutil.copytree(colab_folder_path, os.path.join(full_drive_path, os.path.basename(colab_folder_path)))

        return f"Folder successfully saved to {os.path.join(full_drive_path, os.path.basename(colab_folder_path))}"
    except Exception as e:
        return f"An error occurred: {e}"

In [5]:
def get_number_from_file_name(filename: str):
    filename = filename.split(".")[0][::-1]
    res = ""

    for char in filename:
        try:
            res += str(int(char))
        except:
            break

    return int(res[::-1])

# Board Mapping

## Import Libraries

In [6]:
import math
import operator
import sys
from collections import defaultdict

import numpy as np
import cv2

import scipy.spatial as spatial
import scipy.cluster as clstr
import matplotlib.pyplot as plt

import glob
import os

from itertools import combinations

## Functions

In [7]:
def canny(img):
    # Maybe add some auto thresholding here
    edges = cv2.Canny(img, 80, 200)
    return edges


def hough_lines(img):
    rho, theta, thresh = 2, np.pi / 180, 600
    return cv2.HoughLines(img, rho, theta, thresh)


def sort_lines(lines):
    """
    Sorts lines by horizontal and vertical
    """
    h = []
    v = []
    for i in range(lines.shape[0]):
        rho = lines[i][0][0]
        theta = lines[i][0][1]
        if theta < np.pi / 4 or theta > np.pi - np.pi / 4:
            v.append([rho, theta])
        else:
            h.append([rho, theta])
    return h, v


def calculate_intersections(h, v):
    """
    Finds the intersection of two lines given in Hesse normal form.
    See https://stackoverflow.com/a/383527/5087436
    """
    points = []
    for rho1, theta1 in h:
        for rho2, theta2 in v:
            A = np.array([
                [np.cos(theta1), np.sin(theta1)],
                [np.cos(theta2), np.sin(theta2)]
            ])
            b = np.array([[rho1], [rho2]])
            point = np.linalg.solve(A, b)
            point = int(np.round(point[0])), int(np.round(point[1]))
            points.append(point)
    return np.array(points)


def cluster_intersections(points, max_dist=40):
    # I want to change this to kmeans
    Y = spatial.distance.pdist(points)
    Z = clstr.hierarchy.single(Y)
    T = clstr.hierarchy.fcluster(Z, max_dist, 'distance')
    clusters = defaultdict(list)
    for i in range(len(T)):
        clusters[T[i]].append(points[i])
    clusters = clusters.values()
    clusters = map(lambda arr: (np.mean(np.array(arr)[:, 0]), np.mean(np.array(arr)[:, 1])), clusters)

    result = []
    for point in clusters:
        result.append([point[0], point[1]])
    return result


def find_chessboard_corners(points):
    """
    Code from https://medium.com/@neshpatel/solving-sudoku-part-ii-9a7019d196a2
    """
    # Bottom-right point has the largest (x + y) value
    # Top-left has point smallest (x + y) value
    # Bottom-left point has smallest (x - y) value
    # Top-right point has largest (x - y) value
    bottom_right, _ = max(enumerate([pt[0] + pt[1] for pt in points]), key=operator.itemgetter(1))
    top_left, _ = min(enumerate([pt[0] + pt[1] for pt in points]), key=operator.itemgetter(1))
    bottom_left, _ = min(enumerate([pt[0] - pt[1] for pt in points]), key=operator.itemgetter(1))
    top_right, _ = max(enumerate([pt[0] - pt[1] for pt in points]), key=operator.itemgetter(1))
    return [points[top_left], points[top_right], points[bottom_left], points[bottom_right]]


def distance_between(p1, p2):
    """
    Code from https://medium.com/@neshpatel/solving-sudoku-part-ii-9a7019d196a2
    """
    a = p2[0] - p1[0]
    b = p2[1] - p1[1]
    return np.sqrt((a ** 2) + (b ** 2))


def warp_image(img, edges):
    """
    Code from https://medium.com/@neshpatel/solving-sudoku-part-ii-9a7019d196a2
    """
    top_left, top_right, bottom_left, bottom_right = edges[0], edges[1], edges[2], edges[3]

    # Explicitly set the data type to float32 or 'getPerspectiveTransform' will throw an error
    warp_src = np.array([top_left, top_right, bottom_right, bottom_left], dtype='float32')

    side = max([
        distance_between(bottom_right, top_right),
        distance_between(top_left, bottom_left),
        distance_between(bottom_right, bottom_left),
        distance_between(top_left, top_right)
    ])

    # Describe a square with side of the calculated length, this is the new perspective we want to warp to
    warp_dst = np.array([[0, 0], [side - 1, 0], [side - 1, side - 1], [0, side - 1]], dtype='float32')

    # Gets the transformation matrix for skewing the image to fit a square by comparing the 4 before and after points
    m = cv2.getPerspectiveTransform(warp_src, warp_dst)

    # Performs the transformation on the original image
    return cv2.warpPerspective(img, m, (int(side), int(side)))


def cut_chessboard(img, output_path, output_prefix=""):
    side_len = int(img.shape[0] / 8)
    for i in range(8):
        for j in range(8):
            tile = img[i * side_len: (i + 1) * side_len, j * side_len: (j + 1) * side_len]
            cv2.imwrite(output_path + output_prefix + "-" + str(j + i * 8) + ".jpg", tile)


def resize_image(img):
    """
    Resizes image to a maximum width of 800px
    """
    width = img.shape[1]
    if width > 800:
        scale = 800 / width
        return cv2.resize(img, None, fx=scale, fy=scale)
    else:
        return img


def process_chessboard(src_path, output_path, output_prefix="", debug=False, count=0):
    src = cv2.imread(src_path)

    if src is None:
        sys.exit("There is no file with this path!")

    src = resize_image(src)
    src_copy = src.copy()

    # Convert to grayscale
    process = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)

    if debug:
        plt.imshow(process)
        plt.show()

    # Blur to remove disturbing things
    process = cv2.blur(process, (4, 4))

    if debug:
        plt.imshow(process)
        plt.show()

    # Use Canny Edge Detector https://en.wikipedia.org/wiki/Canny_edge_detector
    process = canny(process)

    if debug:
        plt.imshow(process)
        plt.show()

    # Dilate image (thicker lines)
    process = cv2.dilate(process, np.ones((3, 3), dtype=np.uint8))

    if debug:
        plt.imshow(process)
        plt.show()
    # Use Hough transform to detect lines https://en.wikipedia.org/wiki/Hough_transform
    lines = hough_lines(process)

    # Sort lines by horizontal and vertical
    h, v = sort_lines(lines)

    if debug:
        render_lines(src_copy, h, (0, 255, 0))
        render_lines(src_copy, v, (0, 0, 255))
        plt.imshow(process)
        plt.show()

    if len(h) < 9 or len(v) < 9:
        print("There are not enough horizontal and vertical lines in this image. Try it anyway!")

    # Calculate intersections of the horizontal and vertical lines
    intersections = calculate_intersections(h, v)

    if debug:
        render_intersections(src_copy, intersections, (255, 0, 0), 1)
        plt.imshow(process)
        plt.show()

    # Cluster intersection since there are many
    clustered = cluster_intersections(intersections)

    if debug:
        src_copy = src.copy()
        render_intersections(src_copy, clustered, (255, 0, 0), 5)
        plt.imshow(process)
        plt.show()

    if len(clustered) != 81:
        print("Something is wrong. There are " + str(len(intersections)) + " instead of 81 intersections.")

    # Find outer corners of the chessboard
    corners = find_chessboard_corners(clustered)

    if debug:
        src_copy = src.copy()
        render_intersections(src_copy, corners, (255, 0, 0), 5)
        plt.imshow(process)
        plt.show()

    # Warp and crop image
    dst = warp_image(src, corners)
    try:
        cv2.imwrite(output_path + "crop" + str(count) + ".jpg", dst)
    except:
        print("upload fail")
    plt.show()
    if debug:
        plt.imshow(process)
        plt.show()

def render_lines(img, lines, color):
    for rho, theta in lines:
        a = math.cos(theta)
        b = math.sin(theta)
        x0, y0 = a * rho, b * rho
        pt1 = (int(x0 + 1000 * (-b)), int(y0 + 1000 * a))
        pt2 = (int(x0 - 1000 * (-b)), int(y0 - 1000 * a))
        cv2.line(img, pt1, pt2, color, 1, cv2.LINE_AA)


def render_intersections(img, points, color, size):
    for point in points:
        cv2.circle(img, (int(point[0]), int(point[1])), 2, color, size)

In [8]:
def line_intersections(h_lines, v_lines):
  points = []
  for x,[r_h, t_h] in h_lines:
    row = []
    for y,[r_v, t_v] in v_lines:
      a = np.array([[np.cos(t_h), np.sin(t_h)], [np.cos(t_v), np.sin(t_v)]])
      b = np.array([r_h, r_v])
      inter_point = np.linalg.solve(a, b)
      row.append((int(inter_point[0]),int(inter_point[1])))
    points.append(row)
  return np.array(points)

In [9]:
def rotate_image(image):
    img = image.copy()
    left, right, top, bottom = 0, 0, 0, 0
    table = plot_grid(img)

    for i, row in enumerate(table):
        for j, [x1, y1, x2, y2] in enumerate(row):
            # Validate coordinates
            if x1 < 0 or y1 < 0 or x2 > img.shape[1] or y2 > img.shape[0]:
                # print(f"Skipping invalid crop coordinates: {x1, y1, x2, y2}")
                continue

            crop = img[y1:y2, x1:x2]

            if crop.size == 0:
                # print(f"Empty crop at {i}, {j}")
                continue

            # Apply blur
            crop = cv2.blur(crop, (5, 5))

            # Define color ranges
            lower_black = (0, 0, 0)
            upper_black = (50, 50, 50)
            lower_green = (0, 100, 0)
            upper_green = (100, 255, 100)
            lower_white = (200, 200, 200)
            upper_white = (255, 255, 255)

            # Create masks
            mask_black = cv2.inRange(crop, lower_black, upper_black)
            mask_green = cv2.inRange(crop, lower_green, upper_green)
            mask_white = cv2.inRange(crop, lower_white, upper_white)

            # Exclude background
            mask_background = cv2.bitwise_or(mask_green, mask_white)
            mask_black_filtered = cv2.bitwise_and(mask_black, cv2.bitwise_not(mask_background))

            # Compute histogram
            hist = cv2.calcHist([mask_black_filtered], [0], None, [256], [0, 256])
            total_pixels = np.sum(hist)

            if total_pixels > 0 and (hist[255] / total_pixels >= 0.005):  # Adjust confidence threshold
                if i <= 3:
                    top += 1
                if i > 3:
                    bottom += 1
                if j <= 3:
                    left += 1
                if j > 3:
                    right += 1

    # print("Top, Bottom, Left, Right Counts:", top, bottom, left, right)
    max_value = max(top, bottom, left, right)

    # Rotate based on the detected direction
    if max_value == bottom:
        # print("Do nothing")
        return 0
        # return img
    elif max_value == top:
        # print("180-degree rotation")
        # img = cv2.rotate(image, cv2.ROTATE_180)
        return 1
    elif max_value == left:
        # print("90 degrees counter-clockwise")
        # img = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
        return 2
    elif max_value == right:
        # print("90 degrees clockwise")
        # img = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
        return 3

    return 4

In [10]:
def plot_grid(image, verbose = False):
  def find_best_9_elements(arr, target_distance):
    best_combination = None
    min_score = float('inf')

    for combination in combinations(arr, 9):
        distances = [combination[i+1][0] - combination[i][0] for i in range(8)]
        score = sum(abs(d - target_distance) for d in distances)

        if score < min_score:
            min_score = score
            best_combination = combination

    return best_combination



  def get_lines(lines,length,isHorizon):
    tmp,res, =[],[]
    for rho, theta in lines:
        a = math.cos(theta)
        b = math.sin(theta)
        x0, y0 = a * rho, b * rho
        pts = [int(x0 + 1000 * (-b)), int(y0 + 1000 * a) ,int(x0 - 1000 * (-b)), int(y0 - 1000 * a)]
        if(isHorizon):
          tmp.append([max(pts[1],pts[3]),[rho,theta]])
        else:
          tmp.append([max(pts[0],pts[2]),[rho,theta]])
    tmp.sort()
    mean = length/8
    cur = tmp[0]

    for i in range(len(tmp)):
      if(abs(abs(tmp[i][0]-cur[0])-mean)<=25):
        res.append(cur)
        cur = tmp[i]
      if(i==len(tmp)-1 and len(res)<9):
        res.append(cur)
    if(len(res)<9):
      res = find_best_9_elements(tmp,mean)
    return res

  def locate_board(h,v,r_size,c_size):
    start_point = (v[0],h[0])
    row_length = v[-1]-v[0]
    col_length = h[-1]-h[0]
    return start_point,row_length,col_length



  img = image.copy()
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  img = canny(img)
  img = cv2.dilate(img, np.ones((3, 3), dtype=np.uint8),3)
  lines = hough_lines(img)
  h, v = sort_lines(lines)
  h = get_lines(h,image.shape[0],True)
  v = get_lines(v,image.shape[1],False)
  points = line_intersections(h,v)

  table = []
  for i in range(8):
    rr = []
    for j in range(8):
        rr.append([points[i][j][0],points[i][j][1],points[i+1][j+1][0],points[i+1][j+1][1]])
    #   rr.append([points[j][i][0],points[j][i][1],points[j+1][i+1][0],points[j+1][i+1][1]])
    table.append(rr)


  for r in table:
    for pts in r:
      cv2.rectangle(image, (pts[0], pts[1]), (pts[2], pts[3]), (0, 255, 0), 2)

  if (verbose):
    plt.imshow(image)
    plt.show()

  return table

## Map board

In [11]:
!rm -rf "/content/output"
!mkdir "/content/output"

In [12]:
output_path = 'output/'

filenames = os.listdir(folder_path)
data = []

for filename in filenames:
    frame_id = int(filename.split(".")[0].strip("img"))
    data.append((frame_id, filename))
data.sort()


for (cnt, path) in data:
    full_path = os.path.join(folder_path, path)
    process_chessboard(full_path, "output/", "", False, cnt)

  point = int(np.round(point[0])), int(np.round(point[1]))


Something is wrong. There are 323 instead of 81 intersections.
Something is wrong. There are 306 instead of 81 intersections.
Something is wrong. There are 240 instead of 81 intersections.
Something is wrong. There are 256 instead of 81 intersections.
Something is wrong. There are 289 instead of 81 intersections.


In [13]:
cropped_filenames = os.listdir("output/")
cropped_filenames.sort(key=lambda s: get_number_from_file_name(s))
frames_corners = []

for path in cropped_filenames:
    img = cv2.imread(os.path.join("output/", path))
    tb = plot_grid(img, verbose = False)
    frames_corners.append(np.array(tb))

In [14]:
assert len(frames_corners) == len(os.listdir(folder_path))

## Rotate Board (if needed)

In [15]:
flag = rotate_image(cv2.imread(os.path.join("output/", cropped_filenames[0])))
assert(flag != 4)

if (flag != 0):
    print(["-", "180-degree rotation", "90 degrees counter-clockwise", "90 degrees clockwise"][flag])
    frames_corners = []
    for filename in cropped_filenames:
        img = cv2.imread(os.path.join("output/", filename))
        rotate_image(img)
        img = [img,
            cv2.rotate(img, cv2.ROTATE_180),
            cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE),
            cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
                ][flag]
        cv2.imwrite(os.path.join("output/", filename), img)
        tb = plot_grid(img, verbose = False)
        frames_corners.append(np.array(tb))


In [16]:
assert len(frames_corners) == len(os.listdir(folder_path))

# YOLO setup

## Ultralytics setup

In [17]:
!pip install ultralytics==8.3.40

Collecting ultralytics==8.3.40
  Downloading ultralytics-8.3.40-py3-none-any.whl.metadata (35 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics==8.3.40)
  Downloading ultralytics_thop-2.0.13-py3-none-any.whl.metadata (9.4 kB)
Downloading ultralytics-8.3.40-py3-none-any.whl (898 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m898.5/898.5 kB[0m [31m24.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ultralytics_thop-2.0.13-py3-none-any.whl (26 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.40 ultralytics-thop-2.0.13


In [18]:
import ultralytics
ultralytics.checks()

Ultralytics 8.3.40 üöÄ Python-3.10.12 torch-2.5.1+cu121 CUDA:0 (Tesla T4, 15102MiB)
Setup complete ‚úÖ (2 CPUs, 12.7 GB RAM, 32.6/235.7 GB disk)


In [19]:
from ultralytics import YOLO
from ultralytics import settings

In [20]:
model = YOLO("yolov8n.pt")

Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt to 'yolov8n.pt'...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6.25M/6.25M [00:00<00:00, 96.2MB/s]


# Detect Piece Inference

In [28]:
!wget https://github.com/prinnnnnn/dig-image-final-project/raw/refs/heads/main/weights/best.pt best.pt

--2024-12-10 11:49:21--  https://github.com/prinnnnnn/dig-image-final-project/raw/refs/heads/main/weights/best.pt
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/prinnnnnn/dig-image-final-project/refs/heads/main/weights/best.pt [following]
--2024-12-10 11:49:21--  https://raw.githubusercontent.com/prinnnnnn/dig-image-final-project/refs/heads/main/weights/best.pt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 52040395 (50M) [application/octet-stream]
Saving to: ‚Äòbest.pt‚Äô


2024-12-10 11:49:21 (314 MB/s) - ‚Äòbest.pt‚Äô saved [52040395/52040395]

--2024-12-10 11:49:21--  http://best.pt/
Resol

In [29]:
model = YOLO("best.pt")

In [30]:
uncropped_filenames = os.listdir(folder_path)
uncropped_filenames.sort(key=lambda s: get_number_from_file_name(s))
uncropped_filenames

['img1.jpg', 'img2.jpg', 'img3.jpg', 'img4.jpg', 'img5.jpg']

In [31]:
results = []
frames_classes = []
frames_coors = []
frames_confs = []

# for path in uncropped_filenames:
for path in cropped_filenames:
    img = cv2.imread(os.path.join("output/", path))
    # img = cv2.imread(os.path.join(folder_path, path))
    res = model(img, verbose = False, conf = 0.25)[0]
    classes_labels = np.array(res.boxes.cls.to("cpu"), dtype=int)
    locs = res.boxes.xyxy.to("cpu").numpy()
    confs = res.boxes.conf.to("cpu").numpy()

    # res.show()
    results.append(res.to("cpu"))
    frames_classes.append(classes_labels)
    frames_coors.append(locs)
    frames_confs.append(confs)

In [32]:
cropped_filenames

['crop1.jpg', 'crop2.jpg', 'crop3.jpg', 'crop4.jpg', 'crop5.jpg']

## Map infereneced result to coordinate

In [33]:
assert len(frames_classes) == len(frames_corners)
assert len(frames_coors) == len(frames_corners)
assert len(frames_confs) == len(frames_corners)

In [34]:
label_to_full_names = [
    "white-queen",
    "white-pawn",
    "black-rook",
    "black-bishop",
    "black-knight",
    "black-queen",
    "black-pawn",
    "black-king",
    "white-rook",
    "white-bishop",
    "white-knight",
    "white-king"
]

label_to_piece_names = [
    "Q",
    "P",
    "r",
    "b",
    "n",
    "q",
    "p",
    "k",
    "R",
    "B",
    "N",
    "K"
]

columns = ["h", "g", "f", "e", "d", "c", "b", "a"]
rows = [i+1 for i in range(8)]

In [35]:
encoded_frames = []
confs_frames = []

for i in range(len(frames_corners)): # for each frame

    board = [["." for _ in range(8)] for _ in range(8)]
    piece_conf = [[0 for _ in range(8)] for _ in range(8)]

    corners = frames_corners[i]
    classes = frames_classes[i]
    locs = frames_coors[i]
    confs = frames_confs[i]
    filename = cropped_filenames[i]
    # board_bgr = cv2.imread(f"output/{filename}")
    # board = cv2.cvtColor(board_bgr, cv2.COLOR_BGR2RGB)

    for (j, (cls, (x1, y1, x2, y2), conf)) in enumerate(zip(classes, locs, confs)): # for each piece found
        x1, y1, x2, y2 = round(x1), round(y1), round(x2), round(y2)

        x_mid = (x1 + x2) // 2
        y_base = y2 - (y2 - y1) // 10
        located = False
        # print(f"(x,y) = ({x_mid}, {y_base})")

        for row in range(8):
            if (located):
                break
            for col in range(8):
                if (located):
                    break
                x_l, y_t, x_r, y_b = corners[row][col]

                assert(x_l < x_r)
                assert(y_t < y_b)

                if ((x_l <= x_mid <= x_r) and (y_t <= y_base <= y_b)):
                    # print(f"Found {label_to_piece_names[cls]} at {columns[col]}{rows[row]} with conf: {conf:.4f}")
                    if (piece_conf[row][col] == 0):
                        piece_conf[row][col] = conf
                        board[row][col] = label_to_piece_names[int(results[0].names[cls])]
                        located = True

    nd_confs = np.array(piece_conf)
    nd_board = np.array(board)
    encoded_frames.append(nd_board)
    confs_frames.append(nd_confs)

# Decode board to PGN

In [38]:
board_notation = [
    ['h1', 'g1', 'f1', 'e1', 'd1', 'c1', 'b1', 'a1'],
    ['h2', 'g2', 'f2', 'e2', 'd2', 'c2', 'b2', 'a2'],
    ['h3', 'g3', 'f3', 'e3', 'd3', 'c3', 'b3', 'a3'],
    ['h4', 'g4', 'f4', 'e4', 'd4', 'c4', 'b4', 'a4'],
    ['h5', 'g5', 'f5', 'e5', 'd5', 'c5', 'b5', 'a5'],
    ['h6', 'g6', 'f6', 'e6', 'd6', 'c6', 'b6', 'a6'],
    ['h7', 'g7', 'f7', 'e7', 'd7', 'c7', 'b7', 'a7'],
    ['h8', 'g8', 'f8', 'e8', 'd8', 'c8', 'b8', 'a8']
]

## Functions

In [39]:
def next_moves_knight(board, row, col):
    result = []
    for dx in range(-1, 2):
        for dy in range(-1, 2):
            next_row, next_col = next_row + dy, next_col + dx
            if (next_row >= 0 and next_row < 8 and next_col >= 0 and next_col < 8 and board[next_row][next_col] == '.'):
                result.append(board)

    return result

In [40]:
def check_bishop(start, end, board):
    start_file, start_rank = start
    end_file, end_rank = end
    file_diff = abs(ord(end_file) - ord(start_file))
    rank_diff = abs(int(end_rank) - int(start_rank))

    bn_i = []
    for r in board[::-1]:
      bn_i.append(r[::-1])

    movable = []
    sf1 = ord(start_file)
    sr1 = ord(start_rank)
    pos_file_board1 = sf1 - ord('a')
    pos_rank_board1 =sr1 - ord('1')

    sf2 = ord(start_file)
    sr2 = ord(start_rank)
    pos_file_board2 = sf2 - ord('a')
    pos_rank_board2 =sr2 - ord('1')

    sf3 = ord(start_file)
    sr3 = ord(start_rank)
    pos_file_board3 = sf3 - ord('a')
    pos_rank_board3 =sr3 - ord('1')

    sf4 = ord(start_file)
    sr4 = ord(start_rank)
    pos_file_board4 = sf4 - ord('a')
    pos_rank_board4 =sr4 - ord('1')

    #top right
    while(sf1<= ord('h') and sr1 <= ord('8')):
        if(bn_i[pos_file_board1][pos_rank_board1] == '.'):
            movable.append(chr(sf1)+chr(sr1))
        else:
            break;
        sf1 +=1
        sr1 += 1
        pos_file_board1+=1
        pos_rank_board1+=1

    #bot right
    while(sf2<= ord('h') and sr2 >= ord('1')):
        if(bn_i[pos_file_board2][pos_rank_board2] == '.'):
            movable.append(chr(sf2)+chr(sr2))
        else:
            break;
        sf2 +=  1
        sr2 -= 1
        pos_file_board2+=1
        pos_rank_board2-=1

    #bot left
    while(sf3>= ord('a') and sr3 <= ord('1')):
        if(bn_i[pos_file_board3][pos_rank_board3] == '.'):
            movable.append(chr(sf3)+chr(sr3))
        else:
            break;
        sf3 -=1
        sr3 -= 1
        pos_file_board3 -= 1
        pos_rank_board3 -= 1

      #top left
    while(sf3>= ord('a') and sr3 >= ord('8')):
        if(bn_i[pos_file_board4][pos_rank_board4] == '.'):
            movable.append(chr(sf3)+chr(sr3))
        else:
            break;
        sf3  -=1
        sr3  += 1
        pos_file_board4 -=1
        pos_rank_board4 +=1

    return movable.count(end)

In [41]:
# def check_rook(start, end, board):
#     start_file, start_rank = start
#     end_file, end_rank = end

#     bn_i = []
#     for r in board[::-1]:
#       bn_i.append(r[::-1])

#     movable = []
#     sf1 = ord(start_file)
#     sr1 = ord(start_rank)
#     pos_file_board1 = sf1 - ord('a')
#     pos_rank_board1 =sr1 - ord('1')

#     sf2 = ord(start_file)
#     sr2 = ord(start_rank)
#     pos_file_board2 = sf2 - ord('a')
#     pos_rank_board2 =sr2 - ord('1')

#     sf3 = ord(start_file)
#     sr3 = ord(start_rank)
#     pos_file_board3 = sf3 - ord('a')
#     pos_rank_board3 =sr3 - ord('1')

#     sf4 = ord(start_file)
#     sr4 = ord(start_rank)
#     pos_file_board4 = sf4 - ord('a')
#     pos_rank_board4 =sr4 - ord('1')

#     #top
#     while(sr1 <= ord('8')):
#       if(bn_i[pos_file_board1][pos_rank_board1] == '.'):
#         movable.append(chr(sf1)+chr(sr1))
#       else:
#         break;
#       sr1 += 1
#       pos_rank_board1+=1

#       #bot
#       while(sr2 >= ord('1')):
#         if(bn_i[pos_file_board2][pos_rank_board2] == '.'):
#           movable.append(chr(sf2)+chr(sr2))
#         else:
#           break;
#         sr2 -= 1
#         pos_rank_board2-=1

#         #right
#         while(sf3<= ord('h')):
#           if(bn_i[pos_file_board3][pos_rank_board3] == '.'):
#             movable.append(chr(sf3)+chr(sr3))
#           else:
#             break
#           sr3 += 1
#           pos_file_board3+=1

#         #left
#         while(sf4>= ord('a')):
#           if(bn_i[pos_file_board4][pos_rank_board4] == '.'):
#             movable.append(chr(sf4)+chr(sr4))
#           else:
#             break
#           sr4 -= 1
#           pos_file_board4-=1


#     return movable.count(end)

In [42]:
def get_index(notation, board_notation):
  for row_index, row in enumerate(board_notation):
    for col_index, element in enumerate(row):
      if element == notation:
        return row_index, col_index
  return None

def check_move(board, move):

    def check_pawn(start, end, piece):
        start_file, start_rank = start
        end_file, end_rank = end
        file_diff = abs(ord(end_file) - ord(start_file))
        rank_diff= int(end_rank) - int(start_rank)
        if piece == "P":
            return (start_file == end_file and rank_diff == 1) or (file_diff == 1 and rank_diff == 1) or \
                   (start_file == end_file and rank_diff == 2 and int(start_rank) == 2)
        elif piece == "p":
            return (start_file == end_file and rank_diff == -1) or (file_diff == 1 and rank_diff == -1) or \
                   (start_file == end_file and rank_diff == -2 and int(start_rank) == 7)

    def check_knight(start, end):
        start_file, start_rank = start
        end_file, end_rank = end
        file_diff = abs(ord(end_file) - ord(start_file))
        rank_diff = abs(int(end_rank) - int(start_rank))
        return (file_diff, rank_diff) in [(1, 2), (2, 1)]

    def check_rook(start, end):
        start_file, start_rank = start
        end_file, end_rank = end
        return start_file == end_file or start_rank == end_rank

    def check_bishop(start, end):
        start_file, start_rank = start
        end_file, end_rank = end
        file_diff = abs(ord(end_file) - ord(start_file))
        rank_diff = abs(int(end_rank) - int(start_rank))
        return file_diff == rank_diff

    def check_queen(start, end):
        return check_rook(start, end) or check_bishop(start, end)

    def check_king(start, end):
        start_file, start_rank = start
        end_file, end_rank = end
        file_diff = abs(ord(end_file) - ord(start_file))
        rank_diff = abs(int(end_rank) - int(start_rank))
        return max(file_diff, rank_diff) == 1

    start = move[:2]
    end = move[2:]

    row, col = get_index(start, board_notation)
    piece = board[row][col]

    if piece.lower() == "p":
        return check_pawn(start, end, piece)
    elif piece.lower() == "n":
        return check_knight(start, end)
    elif piece.lower() == "r":
        return check_rook(start, end)
    elif piece.lower() == "b":
        return check_bishop(start, end)
    elif piece.lower() == "q":
        return check_bishop(start, end) or check_rook(start, end)
    elif piece.lower() == "k":
        return check_king(start, end)

def find_move(board_before, board_after):
    moves = []

    for row in range(8):
        for col in range(8):
            piece_before = board_before[row][col]
            piece_after = board_after[row][col]

            if piece_before != piece_after:
                if piece_before != '.':  # A piece moved from this square
                    move_from = board_notation[row][col]
                    next = col
                    for target_row in range(row, 8):
                      for target_col in range(next+1, 8):
                        p_after = board_after[target_row][target_col]
                        if (piece_before == p_after and board_before[target_row][target_col] == '.'):
                          moves.append(move_from+board_notation[target_row][target_col])
                      next = 0

                if piece_after != '.':  # A piece moved to this square
                    move_to = board_notation[row][col]
                    next = col
                    for target_row in range(row, 8):
                      for target_col in range(next+1, 8):
                        p_before = board_before[target_row][target_col]
                        if (piece_after == p_before and board_after[target_row][target_col] == '.'):
                          moves.append(board_notation[target_row][target_col]+move_to)
                      next = 0;

    return moves

In [43]:
def get_turn_start(board_before, move):
  turn_start = ''
  lower_alphabet = 'prbqk'
  upper_alphabet = 'PRBQK'

  start = move[:2]

  start_row, start_col = get_index(start, board_notation)
  piece = board_before[start_row][start_col]

  if piece in lower_alphabet:
    turn_start = 'b'
  elif piece in upper_alphabet:
    turn_start = 'w'

  return turn_start

In [44]:
def uci_to_san(board_before, move):
    start = move[:2]
    end = move[2:]

    # Get the piece at the starting square
    start_row, start_col = get_index(start, board_notation)  # Find the row and column
    piece = board_before[start_row][start_col]  # Get the piece

    end_row, end_col = get_index(end, board_notation)
    captured_piece = board_before[end_row][end_col]

    # Check for captures
    if captured_piece != '.':  # If there's a piece at the destination
        if piece.lower() == 'p':  # If it's a pawn capture
            san = start[0] + 'x' + end  # Add the pawn's file and 'x'
        else:  # If it's another piece capture
            san = piece.upper() + 'x' + end  # Add the piece notation and 'x'
    else:  # If it's not a capture
        if piece.lower() == 'p':
            san = end
        else:
            san = piece.upper() + end

    return san

## Decode

In [45]:
detected_moves = []

for i, (board_before, board_after) in enumerate(zip(encoded_frames[:-1], encoded_frames[1:])):
    # move frame ith to i+1 th
    # print(board_before)
    moves = find_move(board_before, board_after)
    # generate_moves(board_before, position)
    checked_moves = []
    # for m in moves:
        # if (check_move(board_before, m)):
    if (len(moves) > 0):
        checked_moves.append(moves[-1])
    print(f"{i+1}=>{i+2}: {checked_moves}")
    detected_moves.append(checked_moves)

1=>2: ['f6f4']
2=>3: ['h4g6']
3=>4: ['f4g3']
4=>5: []


In [46]:
board_before, board_after

(array([['Q', '.', 'B', 'K', 'Q', '.', '.', 'N'],
        ['P', 'P', 'P', '.', '.', 'P', '.', 'P'],
        ['.', 'p', '.', '.', '.', 'N', '.', '.'],
        ['.', '.', '.', 'P', '.', '.', 'P', '.'],
        ['.', '.', '.', 'p', 'P', '.', '.', 'n'],
        ['.', 'N', '.', '.', 'p', '.', '.', '.'],
        ['.', 'p', '.', '.', 'b', 'p', 'p', 'p'],
        ['r', 'n', 'b', 'k', 'q', '.', '.', 'r']], dtype='<U1'),
 array([['R', '.', 'B', 'Q', 'Q', '.', '.', 'R'],
        ['P', 'P', 'P', '.', '.', 'P', '.', 'P'],
        ['.', 'p', '.', '.', '.', 'N', '.', '.'],
        ['.', '.', '.', 'P', '.', '.', 'P', '.'],
        ['.', '.', '.', 'p', 'P', '.', '.', 'n'],
        ['.', '.', '.', '.', 'p', '.', '.', '.'],
        ['.', 'p', '.', '.', 'b', 'p', 'p', 'p'],
        ['N', 'n', 'b', 'k', 'q', '.', '.', 'r']], dtype='<U1'))

## PGN

In [355]:
detected_moves

[[],
 ['g4e6'],
 ['e8d8'],
 ['e6f7'],
 [],
 ['f7f2'],
 ['c6d5'],
 [],
 [],
 [],
 ['b8c8'],
 [],
 [],
 ['e7c6'],
 ['a7b6']]

In [278]:
for i, moves in enumerate(detected_moves):

    if (len(moves) > 1):
        max_conf_id = 0
        frame_conf = frames_confs[i]

    # transform to PGN


In [453]:
while (len(detected_moves[0]) == 0):
    encoded_frames.pop(0)
    detected_moves.pop(0)

pgn_string = ""
first_move_color = get_turn_start(encoded_frames[0], detected_moves[0][0])
move_number = 1
check = False

san_moves = []
for i, move in enumerate(detected_moves):
    if (len(move) > 0):
        san_moves.append(uci_to_san(encoded_frames[i], move[0]))

for i, move in enumerate(san_moves):
    if first_move_color == "b":
      if i % 2 == 0:
        if not check:
          pgn_string += str(move_number) + ". " + "... " + move + " "
          move_number += 1
          check = True
        else:
          pgn_string += move + " "
          move_number += 1
      else:
        pgn_string += str(move_number) + ". " + move + " "
    else:
      if i % 2 == 0:
        pgn_string += str(move_number) + ". " + move + " "
      else:
        pgn_string += move + " "
        move_number += 1

pgn_string = pgn_string.strip()
print(pgn_string)

1. Bxb5 b6 2. c4 Ne7 3. Rb2
