In [None]:
import cv2
from ultralytics import YOLO
import numpy as np
import chess
import chess.svg
from IPython.display import display, SVG
from collections import Counter
from io import StringIO

def another_model(frame, shrink_factor=0.92):
    """Locate the chessboard outline in a BGR frame and return its corners, shrunk inward.

    The frame is grayscale-thresholded at 190 (assumes the board region shows up
    bright in the mask — TODO confirm for new footage), closed and dilated. The
    largest external contour that approximates to a quadrilateral is then taken
    as the board.

    Args:
        frame: BGR image (as read by ``cv2.VideoCapture``).
        shrink_factor: each corner is moved towards the centroid so that its
            distance becomes ``shrink_factor`` times the original (0.92 by default).
            This presumably trims a border around the playing area.

    Returns:
        ``np.int32`` array of shape (4, 2) with the shrunken corner points, or
        ``None`` when no quadrilateral contour is found. The original code raised
        UnboundLocalError in that case.

    NOTE(review): corners come back in ``approxPolyDP`` contour order, not sorted
    to TL/TR/BR/BL. ``warping`` maps them to TL, TR, BR, BL in the given order,
    so the warped board may be rotated (handled by the orientation flag) or
    mirrored (not handled). Confirm on real footage.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    # Only this bright-region mask feeds the contour search. The Canny /
    # filter2D / XOR pipeline that used to sit here produced an image nobody read.
    _, thresh = cv2.threshold(gray, 190, 255, cv2.THRESH_BINARY)

    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    closed_image = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    dilated_image = cv2.dilate(closed_image, kernel, iterations=1)

    contours, _ = cv2.findContours(dilated_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    best_quad = None
    max_area = 0
    for contour in contours:
        # Approximate the contour to a polygon; keep only quadrilaterals.
        epsilon = 0.02 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        if len(approx) != 4:
            continue
        area = cv2.contourArea(approx)
        if area > max_area:
            best_quad, max_area = approx, area

    if best_quad is None:
        return None

    polygon = best_quad.reshape(4, 2)
    centroid = polygon.mean(axis=0)
    # Scale every corner's offset from the centroid; astype truncates toward
    # zero exactly like the original np.array(..., dtype=np.int32) did.
    shrunken = centroid + (polygon - centroid) * shrink_factor
    return shrunken.astype(np.int32)
# ----------------------------------------------------------------------------------------------------------------------------


def warping(corners, frame):
    """Warp *frame* so the quadrilateral *corners* fills a square top-down image.

    Args:
        corners: four corner points in *frame* coordinates. They are mapped, in
            the given order, to the top-left, top-right, bottom-right and
            bottom-left of the output.
        frame: image passed straight to ``cv2.warpPerspective``.

    Returns:
        ``(warped_image, h_matrix)``. The warped image is 800x800 pixels
        (8 cells of 100 px). ``h_matrix`` is the perspective transform from
        frame coordinates to warped coordinates.
    """
    cell_px = 100  # must match the square size transform_coords divides by
    side = 8 * cell_px

    src = np.array(corners, dtype="float32")
    dst = np.array(
        [[0, 0], [side, 0], [side, side], [0, side]],
        dtype="float32",
    )

    h_matrix = cv2.getPerspectiveTransform(src, dst)
    warped = cv2.warpPerspective(frame, h_matrix, (side, side))
    return warped, h_matrix
    # PIECES LOCATIONS ---------------------
    # Input: dictionary from Neno

    # Identify the position in warped
def transform_coords(detections, h_matrix,frame_count,rotated):
    square_size = 100
    for coords, info in detections.items():
        # Transform the position using the homography matrix
        transformed_to_grid_pos = cv2.perspectiveTransform(np.array([[coords]], dtype="float32"), h_matrix)

        # Map to grid
        transformed_x, transformed_y = transformed_to_grid_pos[0][0]

        # Logic to identify which one is row and column and assign alphabets also
        col = int(transformed_x // square_size) + 1
        row = int(transformed_y // square_size) + 1
        mapping_upright = {
                    11: "a8", 12: "b8", 13: "c8", 14: "d8", 15: "e8", 16: "f8", 17: "g8", 18: "h8",
                    21: "a7", 22: "b7", 23: "c7", 24: "d7", 25: "e7", 26: "f7", 27: "g7", 28: "h7",
                    31: "a6", 32: "b6", 33: "c6", 34: "d6", 35: "e6", 36: "f6", 37: "g6", 38: "h6",
                    41: "a5", 42: "b5", 43: "c5", 44: "d5", 45: "e5", 46: "f5", 47: "g5", 48: "h5",
                    51: "a4", 52: "b4", 53: "c4", 54: "d4", 55: "e4", 56: "f4", 57: "g4", 58: "h4",
                    61: "a3", 62: "b3", 63: "c3", 64: "d3", 65: "e3", 66: "f3", 67: "g3", 68: "h3",
                    71: "a2", 72: "b2", 73: "c2", 74: "d2", 75: "e2", 76: "f2", 77: "g2", 78: "h2",
                    81: "a1", 82: "b1", 83: "c1", 84: "d1", 85: "e1", 86: "f1", 87: "g1", 88: "h1"}
        mapping_upsidedown = {
                    11:"h1", 12:"g1", 13:"f1", 14:"e1", 15:"d1", 16:"c1", 17:"b1", 18:"a1",
                    21:"h2", 22:"g2", 23:"f2", 24:"e2", 25:"d2", 26:"c2", 27:"b2", 28:"a2",
                    31:"h3", 32:"g3", 33:"f3", 34:"e3", 35:"d3", 36:"c3", 37:"b3", 38:"a3",
                    41:"h4", 42:"g4", 43:"f4", 44:"e4", 45:"d4", 46:"c4", 47:"b4", 48:"a4",
                    51:"h5", 52:"g5", 53:"f5", 54:"e5", 55:"d5", 56:"c5", 57:"b5", 58:"a5",
                    61:"h6", 62:"g6", 63:"f6", 64:"e6", 65:"d6", 66:"c6", 67:"b6", 68:"a6",
                    71:"h7", 72:"g7", 73:"f7", 74:"e7", 75:"d7", 76:"c7", 77:"b7", 78:"a7",
                    81:"h8", 82:"g8", 83:"f8", 84:"e8", 85:"d8", 86:"c8", 87:"b8", 88:"a8"}
        mapping_90clockwise = {
                    11:"a1", 12:"a2", 13:"a3", 14:"a4", 15:"a5", 16:"a6", 17:"a7", 18:"a8",
                    21:"b1", 22:"b2", 23:"b3", 24:"b4", 25:"b5", 26:"b6", 27:"b7", 28:"b8",
                    31:"c1", 32:"c2", 33:"c3", 34:"c4", 35:"c5", 36:"c6", 37:"c7", 38:"c8",
                    41:"d1", 42:"d2", 43:"d3", 44:"d4", 45:"d5", 46:"d6", 47:"d7", 48:"d8",
                    51:"e1", 52:"e2", 53:"e3", 54:"e4", 55:"e5", 56:"e6", 57:"e7", 58:"e8",
                    61:"f1", 62:"f2", 63:"f3", 64:"f4", 65:"f5", 66:"f6", 67:"f7", 68:"f8",
                    71:"g1", 72:"g2", 73:"g3", 74:"g4", 75:"g5", 76:"g6", 77:"g7", 78:"g8",
                    81:"h1", 82:"h2", 83:"h3", 84:"h4", 85:"h5", 86:"h6", 87:"h7", 88:"h8"}

        mapping_90counterclockwise = {
                    11:"h8", 12:"h7", 13:"h6", 14:"h5", 15:"h4", 16:"h3", 17:"h2", 18:"h1",
                    21:"g8", 22:"g7", 23:"g6", 24:"g5", 25:"g4", 26:"g3", 27:"g2", 28:"g1",
                    31:"f8", 32:"f7", 33:"f6", 34:"f5", 35:"f4", 36:"f3", 37:"f2", 38:"f1",
                    41:"e8", 42:"e7", 43:"e6", 44:"e5", 45:"e4", 46:"e3", 47:"e2", 48:"e1",
                    51:"d8", 52:"d7", 53:"d6", 54:"d5", 55:"d4", 56:"d3", 57:"d2", 58:"d1",
                    61:"c8", 62:"c7", 63:"c6", 64:"c5", 65:"c4", 66:"c3", 67:"c2", 68:"c1",
                    71:"b8", 72:"b7", 73:"b6", 74:"b5", 75:"b4", 76:"b3", 77:"b2", 78:"b1",
                    81:"a8", 82:"a7", 83:"a6", 84:"a5", 85:"a4", 86:"a3", 87:"a2", 88:"a1"}
        
        if rotated != 0 and frame_count >= 20:
            rowcolumn = int(f"{row}{col}")
            if rotated == 1:
                alph = mapping_upright[rowcolumn]
                detections[coords]["grid_pos"] = alph
            elif rotated == 2:
                alph = mapping_upsidedown[rowcolumn]
                detections[coords]["grid_pos"] = alph
            elif rotated == 3:
                alph = mapping_90clockwise[rowcolumn]
                detections[coords]["grid_pos"] = alph
            elif rotated == 4:
                alph = mapping_90counterclockwise[rowcolumn]
                detections[coords]["grid_pos"] = alph
        elif frame_count >= 20:
            detections[coords]["grid_pos"] = int(f"{row}{col}")
           

        # piece_color = info["color"]
        # piece_abbreviation = info["abb"]
        # piece_grid_position = info["grid_pos"]

        # cv2.circle(warped, (int(transformed_x), int(transformed_y)), 10, (0, 0, 255), -1)
        # cv2.putText(warped, f"{piece_abbreviation}, [{piece_grid_position}] : {row,col}",
        #             (int(transformed_x - 20), int(transformed_y + 20)),
        #             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
        # print(f"{piece_color}{piece_abbreviation} is located at grid position: {piece_grid_position}, grid coord (r,c) {row,col}")

    return detections

def is_integer(value):
    """Return True if ``int(value)`` succeeds, False if it raises ValueError or TypeError.

    This tests "castable to int", not "is an int": both ``3.7`` and ``"12"``
    count, while ``"3.7"`` and ``None`` do not.
    """
    try:
        int(value)
    except (TypeError, ValueError):
        return False
    else:
        return True

def _split_grid_code(grid_pos):
    """Split a raw row/column code such as 85 into (8, 5); return None if it is not an on-board cell."""
    if not isinstance(grid_pos, int):
        return None  # "" placeholder, an algebraic string, None, ...
    row, col = divmod(grid_pos, 10)
    if 1 <= row <= 8 and 1 <= col <= 8:
        return row, col
    return None


def orientation_flag_from_grid(detections):
    """Infer the board orientation flag from the two kings' raw warped-grid codes.

    Expects ``grid_pos`` values in the raw form written by ``transform_coords``
    when the orientation is still unknown, i.e. ``int(f"{row}{col}")`` with
    row 1 at the top of the warped board. The returned flag selects the
    mapping that places the white king on ranks 1-2:

    * 1 (upright): white king in rows 7-8, black king in rows 1-2
    * 2 (180 deg): white king in rows 1-2, black king in rows 7-8
    * 3 (90 deg clockwise): white king in columns 1-2, black king in columns 7-8
    * 4 (90 deg counter-clockwise): white king in columns 7-8, black king in columns 1-2

    The original tests were paired with the opposite mapping in every case.
    Each of them put the white king on rank 7/8.

    Returns:
        The flag 1-4, or 0 when a king is missing, its code is not a valid
        on-board cell, or no pattern matches. The original code returned None
        in the last case.
    """
    w_king_grid = None
    b_king_grid = None

    # If several kings of one colour are detected, the last one wins.
    for info in detections.values():
        if info["abb"] == "king":
            if info["color"] == "white":
                w_king_grid = info["grid_pos"]
            elif info["color"] == "black":
                b_king_grid = info["grid_pos"]

    w_cell = _split_grid_code(w_king_grid)
    b_cell = _split_grid_code(b_king_grid)
    if w_cell is None or b_cell is None:
        return 0
    (w_row, w_col), (b_row, b_col) = w_cell, b_cell

    near = range(1, 3)  # rows/columns 1-2
    far = range(7, 9)   # rows/columns 7-8

    if w_row in far and b_row in near:
        return 1  # Upright
    if w_row in near and b_row in far:
        return 2  # 180-degree rotated
    if w_col in near and b_col in far:
        return 3  # 90-degree clockwise
    if w_col in far and b_col in near:
        return 4  # 90-degree counterclockwise
    return 0
    
def _positions_by_piece(detections):
    """Group the assigned grid positions of *detections* into one Counter per (piece, colour)."""
    grouped = {}
    for info in detections.values():
        grid_pos = info["grid_pos"]
        if grid_pos is None or grid_pos == "":
            continue  # position not assigned yet
        grouped.setdefault((info["abb"], info["color"]), Counter())[grid_pos] += 1
    return grouped


def detect_moves(detection1, detection2):
    """Return the moves implied by two detection snapshots, as UCI strings.

    For every (piece, colour) kind present in both snapshots, the squares it
    left and the squares it newly occupies are found by multiset difference. A
    move is reported only when exactly one square was vacated and exactly one
    was occupied. This keeps identical pieces apart (the old dict keyed by
    (piece, colour) kept only the last pawn, rook, ...). It also ignores kinds
    whose count changed, e.g. a captured or occluded piece.

    Returns:
        A list like ``["e2e4"]``, ordered by first appearance in *detection1*.
        Start and end are concatenated without a separator so that the strings
        can be passed to ``chess.Move.from_uci``. The old format ``"e2 e4"``
        could not.
    """
    before = _positions_by_piece(detection1)
    after = _positions_by_piece(detection2)

    moves = []
    for piece, start_counts in before.items():
        end_counts = after.get(piece)
        if end_counts is None:
            continue
        vacated = list((start_counts - end_counts).elements())
        occupied = list((end_counts - start_counts).elements())
        if len(vacated) == 1 and len(occupied) == 1:
            moves.append(f"{vacated[0]}{occupied[0]}")
    return moves

def parse_grid_pos(grid_pos):
    """Convert an algebraic square name such as ``"e4"`` to a python-chess square index.

    Delegates to ``chess.parse_square``. python-chess documents a ValueError
    for names that are not valid squares.
    """
    square_index = chess.parse_square(grid_pos)
    return square_index

# -------------------------------------------------------------------------------------------------------------------------------------
# Path to the video file
import chess.pgn  # `import chess` alone does not load the pgn submodule used at the end

video_path = 'test_videos/6_Move_student.mp4'
model = YOLO("recog_best.pt")

CORNER_FRAMES = 5        # board corners are (re)detected on frames 0..CORNER_FRAMES
TRACK_START_FRAME = 50   # board position / moves are tracked only from this frame on
STABLE_FRAMES = 5        # consecutive frames with unchanged per-piece counts before a view is trusted

# Maps the piece-name half of a YOLO class label (e.g. "white-pawn") to python-chess piece types.
piece_map = {
    "pawn": chess.PAWN,
    "rook": chess.ROOK,
    "bishop": chess.BISHOP,
    "knight": chess.KNIGHT,
    "queen": chess.QUEEN,
    "king": chess.KING,
}


def _collect_detections(frame):
    """Run YOLO on *frame*, annotate it in place, and return ``{anchor point: piece info}``."""
    detections = {}
    for result in model(frame):
        for box in result.boxes:
            class_id = int(box.cls.item())
            class_name = model.names[class_id]
            confidence = float(box.conf.item())
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())

            # Anchor the piece a quarter of the box height *above the bottom edge*,
            # i.e. near its base, which is what stands on the square.
            anchor = ((x1 + x2) / 2, y2 + (y1 - y2) / 4)

            cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 180, 120), 2)  # light-blue box (BGR)
            cv2.circle(frame, (int(anchor[0]), int(anchor[1])), 5, (0, 255, 0), -1)
            cv2.putText(frame, f"{class_name} {confidence:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 180, 120), 2)

            colour, sep, piece_name = class_name.partition('-')
            if not sep:
                continue  # label is not "<colour>-<piece>"; cannot use it
            detections[anchor] = {
                "abb": piece_name,   # piece name, e.g. "pawn"
                "color": colour,     # colour prefix of the class label
                "grid_pos": "",      # filled in by transform_coords
            }
    return detections


def _pieces_on_squares(detections):
    """Yield ``(square index, chess.Piece)`` for detections assigned to a named square."""
    for info in detections.values():
        piece_type = piece_map.get(info["abb"])
        grid_pos = info["grid_pos"]
        if piece_type is None or grid_pos not in chess.SQUARE_NAMES:
            continue
        color = chess.BLACK if info["color"] == "black" else chess.WHITE
        yield parse_grid_pos(grid_pos), chess.Piece(piece_type, color)


def _push_detected_moves(board, before, after):
    """Push every move implied by two snapshots that is legal in sequence; return the pushed moves."""
    pending = []
    for candidate in detect_moves(before, after):
        try:
            # Accept both "e2e4" and the older "e2 e4" candidate format.
            pending.append(chess.Move.from_uci(str(candidate).replace(" ", "")))
        except ValueError:
            continue
    pushed = []
    while True:
        legal = [m for m in pending if board.is_legal(m)]
        if not legal:
            return pushed
        # Castling appears as a king move plus a rook move. The king move is the
        # UCI move, and pushing it relocates the rook too.
        move = next((m for m in legal if board.piece_type_at(m.from_square) == chess.KING), legal[0])
        board.push(move)
        pushed.append(move)
        pending.remove(move)


def _show_board(board):
    """Render the current board in the notebook."""
    display(SVG(chess.svg.board(board)))


# Create a VideoCapture object
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    print(f"Could not open video: {video_path}")

frame_count = 0
rotated = 0              # orientation flag from orientation_flag_from_grid (0 = unknown)
corners = None
h_matrix = None
prev_abb_count = None
consistency_count = 0
stable_snapshot = None   # last trusted detections that `board` reflects

# Initialize an empty chess board
board = chess.Board(None)  # None initializes an empty board

try:
    while cap.isOpened():
        # Read a frame
        ret, frame = cap.read()

        # Break the loop if there are no frames left
        if not ret:
            print("End of video.")
            break

        # Detect the board on the first frames, and keep retrying until it has been
        # found once. The homography is only recomputed when the corners change.
        if frame_count <= CORNER_FRAMES or corners is None:
            found = another_model(frame)
            if found is not None:
                corners = found
                _, h_matrix = warping(corners, frame)

        detections = _collect_detections(frame)

        if h_matrix is not None:
            transformed = transform_coords(detections, h_matrix, frame_count, rotated)

            # Count-based stability: a hand over the board typically hides pieces.
            abb_counts = Counter(info["abb"] for info in transformed.values())
            consistency_count = consistency_count + 1 if abb_counts == prev_abb_count else 0
            prev_abb_count = abb_counts

            # `rotated` is still the flag transform_coords just used, so grid_pos
            # values are algebraic squares whenever it is non-zero.
            if rotated and frame_count >= TRACK_START_FRAME and consistency_count >= STABLE_FRAMES:
                if stable_snapshot is None:
                    # First trusted view: set the position up once.
                    for square, piece in _pieces_on_squares(transformed):
                        board.set_piece_at(square, piece)
                    # Board(None) starts with no castling rights. Grant those the setup allows.
                    board.castling_rights = chess.BB_CORNERS
                    board.castling_rights = board.clean_castling_rights()
                    stable_snapshot = transformed
                    _show_board(board)
                elif _push_detected_moves(board, stable_snapshot, transformed):
                    stable_snapshot = transformed
                    _show_board(board)

            if not rotated:
                # `or 0` keeps the "unknown" sentinel even if no orientation matched.
                rotated = orientation_flag_from_grid(transformed) or 0

        cv2.imshow('Video Playback', frame)
        frame_count += 1
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the video capture object and close display windows
    cap.release()
    cv2.destroyAllWindows()

# from_board records the set-up position (FEN header) plus the pushed moves.
game = chess.pgn.Game.from_board(board)

pgn = StringIO()
print(game, file=pgn)
print("PGN format:")
print(pgn.getvalue())