In [None]:
from itertools import combinations
import math
from matplotlib import pyplot as plt
import numpy as np
import cv2
import matplotlib.pyplot as plt
import serial
import time
import cv2
import numpy as np
from scipy.interpolate import RBFInterpolator

class Edge:
    """One side of a puzzle piece, plus the rules for which sides may be paired.

    Attributes:
        edge_type: ``'straight'``, ``'convex'`` or ``'concave'``.
        shape: Sampled outline of the edge.
        colors: Colour samples taken along the edge.
        position: ``'top'``, ``'bottom'``, ``'left'`` or ``'right'``; None until
            the owning piece assigns it.
        piece: The piece this edge belongs to; None until the piece assigns it.
    """

    def __init__(self, edge_type, shape, colors):
        self.edge_type = edge_type  # 'straight', 'convex' or 'concave'
        self.shape = shape  # sampled edge outline
        self.colors = colors  # colour samples along the edge
        self.position = None  # 'top', 'bottom', 'left' or 'right'
        self.piece = None  # owning piece

    def can_match_with(self, other_edge):
        """Return True if the puzzle rules allow this edge to join ``other_edge``.

        Only edge types, the owning pieces' ``piece_type`` and border-edge
        ``position`` labels are consulted; shape and colour are not.  For
        non-straight edges both edges must already have ``piece`` set.
        """
        # Straight edges form the outside of the puzzle and never join anything.
        if self.edge_type == 'straight' or other_edge.edge_type == 'straight':
            return False

        # A tab only fits a blank: convex/convex and concave/concave never match.
        if self.edge_type == other_edge.edge_type:
            return False

        type1 = self.piece.piece_type
        type2 = other_edge.piece.piece_type

        # Corner pieces may not join another corner or an interior piece.
        if (type1, type2) in (('corner', 'corner'), ('corner', 'interior'), ('interior', 'corner')):
            return False

        if 'border' in (type1, type2):
            # Two border pieces can only sit side by side: one's right to the other's left.
            if type1 == 'border' and type2 == 'border':
                if {self.position, other_edge.position} != {'left', 'right'}:
                    return False

            # An interior piece never touches a border piece's left or right side.
            if (type1 == 'border' and self.position in ('left', 'right') and type2 == 'interior') or \
               (type2 == 'border' and other_edge.position in ('left', 'right') and type1 == 'interior'):
                return False

            # A border piece's bottom side (opposite its straight edge) only meets interior pieces.
            if type1 == 'border':
                if self.position == 'bottom' and type2 != 'interior':
                    return False
            elif type2 == 'border' and other_edge.position == 'bottom' and type1 != 'interior':
                return False

        return True

class Piece:
    """A segmented puzzle piece: its four edges, image crop and corner points.

    Edge ``i`` is expected to run from ``corners[i]`` to ``corners[(i + 1) % 4]``.

    Args:
        piece_id: Zero-based piece index.
        edges: Four edge objects in contour order.
        image: Cropped piece image; only its ``shape`` is read by this class.
        corners: Four ``(x, y)`` corner points in cropped-image coordinates.
        original_position: Bounding box of the piece in the uncropped input
            image (presumably ``(x, y, w, h)`` -- confirm against the caller), or None.
        crop_values: 4-tuple whose first two entries are the ``(x, y)`` offset
            that maps cropped coordinates back to the input image, or None.
    """

    def __init__(self, piece_id, edges, image, corners, original_position=None, crop_values=None):
        self.id = piece_id
        self.edges = edges
        self.image = image
        self.corners = corners
        self.piece_type = self.determine_piece_type()
        self.set_edge_positions()
        self.current_orientation = 0
        for edge in self.edges:
            edge.piece = self
        self.center = self.calculate_center()
        self.original_position = original_position  # bounding box in the uncropped input image
        self.cropped_dimensions = image.shape[:2]  # (height, width) of the cropped image
        self.crop_values = crop_values  # (x, y, _, _) offset of the crop
        # edge index -> (other_piece, other_edge_index) once a match is recorded
        self.connections = {0: None, 1: None, 2: None, 3: None}

    def determine_piece_type(self):
        """Classify by straight-edge count: 2 -> 'corner', 1 -> 'border', otherwise 'interior'."""
        straight_edges = sum(1 for edge in self.edges if edge.edge_type == 'straight')
        if straight_edges == 2:
            return 'corner'
        elif straight_edges == 1:
            return 'border'
        else:
            return 'interior'

    def set_edge_positions(self):
        """Assign a 'top'/'right'/'bottom'/'left' label to every edge.

        * Corner: the straight edge whose successor is also straight becomes
          'top', and the following edges become 'right', 'bottom', 'left'.  This
          keeps opposite edges (i and i + 2) on opposite labels, including when
          the straight pair wraps around as indices (3, 0).
        * Border: the straight edge becomes 'top', followed by 'right',
          'bottom', 'left'.
        * Interior: there is no reference edge, so labels are assigned from
          index 0 as 'top', 'left', 'bottom', 'right'.
        """
        cycle = ('top', 'right', 'bottom', 'left')
        if self.piece_type == 'corner':
            straight = [i for i, edge in enumerate(self.edges) if edge.edge_type == 'straight']
            top = next((i for i in straight if (i + 1) % 4 in straight), None)
            if top is not None:
                for offset, pos in enumerate(cycle):
                    self.edges[(top + offset) % 4].position = pos
            else:
                # The straight edges are opposite each other, so this is not a real
                # corner; keep the previous labelling for that case.
                self.edges[straight[0]].position = 'top'
                self.edges[straight[1]].position = 'right'
                others = [i for i in range(4) if i not in straight]
                self.edges[others[0]].position = 'bottom'
                self.edges[others[1]].position = 'left'
        elif self.piece_type == 'border':
            straight_index = next(i for i, edge in enumerate(self.edges) if edge.edge_type == 'straight')
            for offset, pos in enumerate(cycle):
                self.edges[(straight_index + offset) % 4].position = pos
        else:
            for i, pos in enumerate(['top', 'left', 'bottom', 'right']):
                self.edges[i].position = pos

    def get_non_straight_edges(self):
        """Return the indices of all edges that are not straight."""
        return [i for i, edge in enumerate(self.edges) if edge.edge_type != 'straight']

    def calculate_center(self):
        """Return the centre ``(x, y)`` of the cropped image, using integer division."""
        height, width = self.image.shape[:2]
        return (width // 2, height // 2)

    def get_corner_positions(self):
        """Return the corners as a new list, in cropped-image coordinates."""
        return list(self.corners)

    def get_absolute_corner_positions(self):
        """Return the corners in input-image coordinates.

        The ``(x, y)`` crop offset from ``crop_values`` is added to each corner.
        If either ``original_position`` or ``crop_values`` is missing, the
        cropped-image corners are returned unchanged.
        """
        if self.original_position is None or self.crop_values is None:
            return self.corners

        crop_x, crop_y, _, _ = self.crop_values
        return [(int(crop_x + x), int(crop_y + y)) for x, y in self.corners]

    def connect(self, edge, other_piece, other_edge):
        """Record a match between this piece's ``edge`` and ``other_piece``'s ``other_edge`` on both pieces."""
        self.connections[edge] = (other_piece, other_edge)
        other_piece.connections[other_edge] = (self, edge)
    
def show_image(image, title='Image', cmap_type='gray', figsize=(15, 15)):
    """Display ``image`` in a new figure of size ``figsize``, titled ``title``, with axes hidden.

    The image is handed to ``imshow`` unchanged (no BGR -> RGB conversion),
    with ``cmap_type`` as the colour map.
    """
    fig = plt.figure(figsize=figsize)
    ax = fig.gca()
    ax.imshow(image, cmap=cmap_type)
    ax.set_title(title)
    ax.axis('off')
    plt.show()

def show_images(images, titles, figsize=(20, 5)):
    """Show images side by side in a single row, each with its title.

    Single-channel images are drawn with a grey colour map.  Anything else
    is assumed to be OpenCV BGR and is converted to RGB before display.
    """
    # squeeze=False always yields a 2-D grid, so one image needs no special case.
    _, grid = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    for ax, image, title in zip(grid[0], images, titles):
        if len(image.shape) != 2:
            ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        else:
            ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


def extract_puzzle_pieces(image_path, left=350, top=450, right=1100, bottom=1500):
    """Segment puzzle pieces photographed on a green background.

    The image is cropped to ``[top:bottom, left:right]``.  A pixel counts as
    background when its hue and value lie in a fixed green band and its
    saturation is at least the mean saturation of that band.  The remaining
    foreground is cleaned with morphology and a median blur.  Every external
    contour whose bounding box is larger than 100x100 px is returned as a piece.

    Args:
        image_path: Path of the input photo, read with ``cv2.imread``.
        left, top, right, bottom: Crop window in input-image pixels.

    Returns:
        ``(masks, colour_pieces, bounding_boxes, crop)``:
        * ``masks``: filled per-piece masks, the size of the crop.
        * ``colour_pieces``: the cropped image with everything outside the
          piece set to black.
        * ``bounding_boxes``: ``(x, y, w, h)`` in input-image (uncropped)
          coordinates.
        * ``crop``: ``(left, top, right, bottom)``.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    cropped_image = image[top:bottom, left:right]
    hsv = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2HSV)

    # Green-background band.  Only the hue (index 0) and value (index 2)
    # bounds are applied; saturation is thresholded adaptively below.
    min_hsv = (75, 140, 20)
    max_hsv = (87, 255, 210)

    hue, sat, val = hsv[..., 0], hsv[..., 1], hsv[..., 2]
    in_band = ((hue >= min_hsv[0]) & (hue <= max_hsv[0])
               & (val >= min_hsv[2]) & (val <= max_hsv[2]))

    # Mean saturation of the in-band pixels.  ndarray.mean accumulates in
    # float64.  The old per-pixel Python sum of uint8 scalars can wrap around
    # under NumPy 2's (NEP 50) promotion rules, and was very slow.
    mean_s = float(sat[in_band].mean()) if in_band.any() else 0

    # Background -> 0, everything else (the pieces) -> 255.
    mask = np.where(in_band & (sat >= mean_s), 0, 255).astype(np.uint8)

    # Clean up the piece mask: 3x3 closing then opening, then an 11x11 median blur.
    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    median_mask = cv2.medianBlur(mask, 11)

    contours, _ = cv2.findContours(median_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    extracted_pieces_mask = []
    extracted_pieces_colour = []
    bounding_boxes = []

    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Minimum size filter to drop small blobs.
        if w > 100 and h > 100:
            contour_mask = np.zeros_like(mask)
            cv2.drawContours(contour_mask, [contour], -1, 255, thickness=cv2.FILLED)

            piece = cv2.bitwise_and(cropped_image, cropped_image, mask=contour_mask)

            extracted_pieces_mask.append(contour_mask)
            extracted_pieces_colour.append(piece)

            # Store the bounding box in input-image (uncropped) coordinates.
            bounding_boxes.append((x + left, y + top, w, h))

    return extracted_pieces_mask, extracted_pieces_colour, bounding_boxes, (left, top, right, bottom)


def calc_distance(p1, p2):
    """Return the Euclidean distance between 2-D points ``p1`` and ``p2``."""
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return np.sqrt(dx**2 + dy**2)

def angle_between(p1, p2, p3):
    """Return the unsigned angle, in degrees (0 to 180), at vertex ``p2`` between the rays to ``p1`` and ``p3``."""
    vertex = np.array(p2)
    ray_a = np.array(p1) - vertex
    ray_b = np.array(p3) - vertex
    # atan2(cross, dot) gives the signed angle; its magnitude is the unsigned angle.
    signed = math.atan2(np.linalg.det([ray_a, ray_b]), np.dot(ray_a, ray_b))
    return np.abs(np.degrees(signed))

def refine_corners(contour, num_corners=4):
    """Pick the ``num_corners`` polygon vertices that best form a rectangle.

    The contour is simplified with ``cv2.approxPolyDP``, using a tolerance of
    1% of its perimeter.  Every combination of ``num_corners`` vertices (kept
    in contour order) is scored as

        sum(|interior angle - 90|) + 0.01 * var(pairwise distances)
            - 0.3 * sum(pairwise distances)

    and the lowest score wins (the first one on ties).  This prefers right
    angles, even spacing and widely spread vertices.

    Returns:
        Tuple of ``num_corners`` ``(x, y)`` vertex arrays, or None when the
        simplified polygon has fewer than ``num_corners`` vertices.
    """
    tolerance = 0.01 * cv2.arcLength(contour, True)
    vertices = cv2.approxPolyDP(contour, tolerance, True)[:, 0, :]

    def rectangularity_cost(quad):
        angle_diff = sum(
            abs(angle - 90)
            for angle in (
                angle_between(quad[k - 1], quad[k], quad[(k + 1) % num_corners])
                for k in range(num_corners)
            )
        )
        pair_lengths = [calc_distance(a, b) for a, b in combinations(quad, 2)]
        return angle_diff + np.var(pair_lengths) * 0.01 - sum(pair_lengths) * 0.3

    return min(combinations(vertices, num_corners), key=rectangularity_cost, default=None)

import numpy as np

def improve_corner(contour, corner, num_points=30):
    """Snap ``corner`` to where the two neighbouring contour sides would meet.

    One line is fitted through the contour points ``num_points`` and
    ``num_points // 2`` steps before the corner, and another through the
    points ``num_points // 2`` and ``num_points`` steps after it.  Indices
    wrap around the contour.  The contour point closest to the intersection
    of the two lines becomes the refined corner.

    Args:
        contour: OpenCV contour of shape ``(N, 1, 2)``.
        corner: A point ``(x, y)`` that lies exactly on ``contour``.
        num_points: Contour step distance used to place the fitting points.

    Returns:
        ``(refined_corner, (left_point1, left_point2, right_point1, right_point2))``.
        If the lines are parallel, or either fitting pair is a single
        repeated point, ``corner`` itself is returned unchanged.

    Raises:
        IndexError: If ``corner`` is not a point of ``contour``.
    """
    n = len(contour)
    corner_index = np.where((contour == corner).all(axis=2))[0][0]

    left_point1 = contour[(corner_index - num_points) % n][0]
    left_point2 = contour[(corner_index - num_points // 2) % n][0]
    right_point1 = contour[(corner_index + num_points // 2) % n][0]
    right_point2 = contour[(corner_index + num_points) % n][0]
    fit_points = (left_point1, left_point2, right_point1, right_point2)

    # Intersect the lines in parametric form.  The old slope/intercept form
    # needed float('inf') slopes for vertical lines, which produced NaN
    # intercepts and RuntimeWarnings.
    p = left_point1.astype(float)
    d1 = left_point2.astype(float) - p
    q = right_point1.astype(float)
    d2 = right_point2.astype(float) - q

    denom = d1[0] * d2[1] - d1[1] * d2[0]  # 2-D cross product d1 x d2
    if denom == 0:
        # Parallel lines, or a zero-length fitting pair: no usable intersection.
        return corner, fit_points

    offset = q - p
    t = (offset[0] * d2[1] - offset[1] * d2[0]) / denom
    intersection = p + t * d1

    # The refined corner must be a real contour point: take the nearest one.
    distances = np.sqrt(((contour[:, 0, :] - intersection) ** 2).sum(axis=1))
    refined_corner = contour[np.argmin(distances)][0]

    return refined_corner, fit_points

def improve_corners(contour, initial_corners):
    """Refine every corner in ``initial_corners`` with ``improve_corner``.

    Returns:
        ``(corners, fit_points)``: the refined corners stacked into an array,
        and, per corner, the four contour points used to fit its two lines.
    """
    results = [improve_corner(contour, corner) for corner in initial_corners]
    refined = np.array([corner for corner, _ in results])
    fit_points = [points for _, points in results]
    return refined, fit_points

def classify_edge(contour, start, end, straight_threshold=15):
    """Classify the contour span from ``start`` up to (not including) ``end``.

    When ``start >= end`` the span wraps around the end of the contour.
    The span is compared against its chord (first point to last point): the
    point with the largest perpendicular distance from the chord decides.

    * max |distance| < ``straight_threshold`` -> ``"straight"``
    * farthest point on the positive side of the chord -> ``"convex"``
    * otherwise -> ``"concave"``

    Args:
        contour: ``(N, 2)`` array of contour points.
        start, end: Indices into ``contour``.
        straight_threshold: Deviation in pixels below which the span counts
            as straight (default 15, the previously hard-coded value).

    Returns:
        ``"straight"``, ``"convex"`` or ``"concave"``.
    """
    edge_points = contour[start:end] if start < end else np.vstack((contour[start:], contour[:end]))

    origin = edge_points[0]
    line_vec = edge_points[-1] - origin
    line_length = np.linalg.norm(line_vec)
    rel = edge_points - origin

    # Signed perpendicular distances.  The 2-D cross product is written out
    # because np.cross on 2-element vectors is deprecated in NumPy 2.
    distances = (line_vec[0] * rel[:, 1] - line_vec[1] * rel[:, 0]) / line_length

    farthest = np.argmax(np.abs(distances))
    max_dist = abs(distances[farthest])
    max_dist_signed = distances[farthest]

    if max_dist < straight_threshold:
        return "straight"
    if max_dist_signed > 0:
        return "convex"
    return "concave"

def rotate_points(points, edge_type):
    """Translate and rotate a polyline so its chord lies along the +x axis.

    The first point moves to the origin and the last point onto the positive
    x axis.  For ``'concave'`` edges the y axis is also mirrored, so a
    concave edge can be compared directly with a convex one.

    Args:
        points: ``(N, 2)`` array of points.
        edge_type: ``'straight'``, ``'convex'`` or ``'concave'``.

    Returns:
        ``(N, 2)`` array of transformed points.
    """
    origin = points[0]
    dx, dy = points[-1] - origin

    theta = -np.arctan2(dy, dx)  # undo the chord's angle
    c, s = np.cos(theta), np.sin(theta)
    rotation = np.array([[c, -s], [s, c]])

    aligned = np.dot(points - origin, rotation.T)
    if edge_type == 'concave':
        aligned[:, 1] *= -1
    return aligned

def sample_edge_shape(contour, start_index, end_index, start_point, end_point, edge_type, num_samples=50):
    """Resample an edge to ``num_samples`` points and normalise its pose.

    Points are picked at evenly spaced indices of the contour span
    ``[start_index, end_index)``, which wraps when ``start_index >= end_index``.
    They are shifted by ``start_point`` and passed through ``rotate_points``.
    ``end_point`` is accepted but not used.

    Returns:
        ``(num_samples, 2)`` array from ``rotate_points``.
    """
    if start_index < end_index:
        span = contour[start_index:end_index]
    else:
        span = np.vstack((contour[start_index:], contour[:end_index]))

    picks = np.linspace(0, len(span) - 1, num_samples).astype(int)
    shifted = span[picks].reshape(-1, 2) - start_point
    return rotate_points(shifted, edge_type)

def sample_edge_colors(image, contour, start_index, end_index, num_samples=30, inset_distance=-5):
    """Sample pixel colours a fixed distance from an edge, along its normal.

    ``num_samples`` evenly spaced points are taken from the contour span
    ``[start_index, end_index)``, which wraps when ``start_index >= end_index``.
    Each point is moved ``inset_distance`` pixels along the local unit normal
    ``(-dy, dx)`` (the offset is truncated to whole pixels), and the pixel
    there is read.  Which side of the edge a positive ``inset_distance``
    reaches depends on the contour's orientation.

    Args:
        image: Image indexed as ``image[y, x]``.
        contour: OpenCV contour of shape ``(N, 1, 2)``.
        start_index, end_index: Span of the edge within ``contour``.
        num_samples: Number of samples to take.
        inset_distance: Signed offset in pixels along the normal.

    Returns:
        ``(colors, points)``: the sampled pixel values and the ``(x, y)``
        locations they were read from, clamped to the image bounds.
    """
    edge_points = contour[start_index:end_index] if start_index < end_index else np.vstack((contour[start_index:], contour[:end_index]))
    sample_idx = np.linspace(0, len(edge_points) - 1, num_samples).astype(int)
    points = edge_points[sample_idx].reshape(-1, 2)

    # Tangents by central differences, one-sided at the two ends.  The edge is
    # an open span; wrapping to the opposite end (as before) flipped the
    # normal at the first and last samples.
    prev_pts = np.vstack((points[:1], points[:-1]))
    next_pts = np.vstack((points[1:], points[-1:]))
    tangents = (next_pts - prev_pts).astype(float)

    normals = np.column_stack((-tangents[:, 1], tangents[:, 0]))
    lengths = np.linalg.norm(normals, axis=1, keepdims=True)
    # Repeated samples give a zero tangent: sample on the contour itself there
    # instead of dividing by zero.
    normals = np.divide(normals, lengths, out=np.zeros_like(normals), where=lengths > 0)

    inset_points = points + (normals * inset_distance).astype(int)
    # Clamp so that points past the border neither wrap (negative index) nor raise.
    height, width = image.shape[:2]
    inset_points[:, 0] = np.clip(inset_points[:, 0], 0, width - 1)
    inset_points[:, 1] = np.clip(inset_points[:, 1], 0, height - 1)

    colors = image[inset_points[:, 1], inset_points[:, 0]]
    return colors, inset_points

def analyze_piece_edges(image, mask, num_samples=100, inset_distance=5):
    """Find a piece's four corners and describe each edge between them.

    The largest external contour of ``mask`` is taken as the piece outline.
    Corners come from ``refine_corners`` and are refined by ``improve_corners``.

    Args:
        image: Piece image; colours are sampled from it, and it is copied
            (not modified) for drawing.
        mask: Single-channel mask of the piece.
        num_samples: Samples per edge, for both shape and colour.
        inset_distance: Offset passed to ``sample_edge_colors``.

    Returns:
        ``(edge_types, edge_shapes, edge_colors, color_points, annotated,
        initial_corners, corners, fit_points)``:
        * The first four are per-edge lists.
        * ``annotated`` is a copy of ``image`` with the contour drawn in green
          and each edge drawn in a colour chosen by its type.
        * ``initial_corners`` are the coarse corners, ``corners`` the refined
          ones, and ``fit_points`` the line-fitting points.
    """
    found, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    contour = max(found, key=cv2.contourArea)

    initial_corners = refine_corners(contour)
    corners, vector_points = improve_corners(contour, initial_corners)

    annotated = image.copy()
    cv2.drawContours(annotated, [contour], 0, (0, 255, 0), 2)

    # Drawing colour per edge type.
    palette = {
        "straight": (255, 255, 0),
        "convex": (0, 255, 255),
        "concave": (255, 0, 255)}

    def index_of(point):
        return np.where((contour == point).all(axis=2))[0][0]

    edge_types, edge_shapes, edge_colors, color_points = [], [], [], []
    for k in range(4):
        first, second = corners[k], corners[(k + 1) % 4]
        start_index = index_of(first)
        end_index = index_of(second)

        edge_type = classify_edge(contour[:, 0, :], start_index, end_index)
        edge_types.append(edge_type)

        edge_shapes.append(sample_edge_shape(contour, start_index, end_index, first, second, edge_type, num_samples))

        colors, points = sample_edge_colors(image, contour, start_index, end_index, num_samples, inset_distance)
        edge_colors.append(colors)
        color_points.append(points)

        stroke = palette[edge_type]
        if start_index < end_index:
            span = contour[start_index:end_index]
        else:
            span = np.vstack((contour[start_index:], contour[:end_index]))
        cv2.drawContours(annotated, [span], 0, stroke, 2)

    return edge_types, edge_shapes, edge_colors, color_points, annotated, initial_corners, corners, vector_points


def smooth_curve(points, w=3):
    """Apply a moving-average smooth to ``points`` along axis 0.

    Each output row is the mean of the input rows within ``w // 2``
    positions of it; the window is truncated at both ends.  The result has
    the same shape and dtype as ``points``.
    """
    half = w // 2
    smoothed = np.zeros_like(points)
    for idx in range(len(points)):
        window = points[max(0, idx - half):idx + half + 1]
        smoothed[idx] = np.mean(window, axis=0)
    return smoothed

def compare_edge_colors(colors1, colors2, reverse=False):
    """Return the mean Lab-space distance between two colour strips.

    Args:
        colors1, colors2: Sequences of BGR colour samples, each ``(N, 3)``.
        reverse: If True, reverse ``colors2`` before comparing.

    Returns:
        Mean Euclidean distance between corresponding samples, in OpenCV's
        8-bit Lab encoding.  The longer strip is truncated to the shorter one.
    """
    if reverse:
        colors2 = colors2[::-1]

    n = min(len(colors1), len(colors2))
    colors1 = np.uint8(np.clip(colors1[:n], 0, 255))
    colors2 = np.uint8(np.clip(colors2[:n], 0, 255))

    # cvtColor expects an image, so treat each strip as a 1 x N image.
    lab1 = cv2.cvtColor(colors1.reshape(1, -1, 3), cv2.COLOR_BGR2Lab).reshape(-1, 3).astype(np.float64)
    lab2 = cv2.cvtColor(colors2.reshape(1, -1, 3), cv2.COLOR_BGR2Lab).reshape(-1, 3).astype(np.float64)

    # Work in float: subtracting and squaring the uint8 Lab values directly
    # wraps around and gives meaningless distances.
    color_diffs = np.sqrt(np.sum((lab1 - lab2) ** 2, axis=1))

    return np.mean(color_diffs)

def _axis_correlation(a, b):
    """Return sum(a*b) / (std(a) * std(b) * len(a)) for centred 1-D arrays; 0.0 if either is constant."""
    denom = np.std(a) * np.std(b) * len(a)
    if denom == 0:
        # A constant coordinate has no defined correlation; 0 avoids NaN/inf
        # leaking into the min/max normalisation of the scores.
        return 0.0
    return np.correlate(a, b)[0] / denom


def compare_edges(edge1, edge2):
    """Score how well two edges fit together.

    ``edge2``'s shape is reversed and mirrored in x about the centre of its
    x-range, so that a true fit overlays ``edge1``.  Both shapes are centred
    on their mean and smoothed with ``smooth_curve`` before comparison.

    Returns:
        ``(correlation, mse, color_diff)``:
        * ``correlation``: mean of the x and y correlations (0 for an axis
          with no variation).
        * ``mse``: mean squared point-to-point distance.
        * ``color_diff``: ``compare_edge_colors`` of the two colour strips.
    """
    shape1 = np.array(edge1.shape, dtype=float)
    # Float copy, so the in-place mirror below cannot truncate integer shapes.
    shape2 = np.array(edge2.shape, dtype=float)[::-1]

    center_x = (np.max(shape2[:, 0]) + np.min(shape2[:, 0])) / 2
    shape2[:, 0] = 2 * center_x - shape2[:, 0]

    shape1 = smooth_curve(shape1 - np.mean(shape1, axis=0))
    shape2 = smooth_curve(shape2 - np.mean(shape2, axis=0))

    correlation = (_axis_correlation(shape1[:, 0], shape2[:, 0])
                   + _axis_correlation(shape1[:, 1], shape2[:, 1])) / 2
    mse = np.mean(np.sum((shape1 - shape2) ** 2, axis=1))

    color_diff = compare_edge_colors(np.array(edge1.colors), np.array(edge2.colors))

    return correlation, mse, color_diff


def normalize_score(score, min_score, max_score):
    """Min-max scale ``score`` relative to ``[min_score, max_score]``; a zero-width range maps to 0."""
    if max_score != min_score:
        return (score - min_score) / (max_score - min_score)
    return 0

def compare_all_pieces(pieces, weights=(0.15, 3, 0.75)):
    """Score every allowed edge pairing between every ordered pair of pieces.

    For each ordered pair ``(i, j)`` with ``i != j``, and each edge pair that
    passes ``can_match_with``, ``compare_edges`` is run.  The three raw
    metrics are min-max normalised over all comparisons and combined as

        combined_score = (1 - norm_correlation) * w_corr
                         + norm_mse * w_mse + norm_color_diff * w_color

    Lower scores are better.

    Args:
        pieces: Sequence of pieces, indexed by position.
        weights: ``(w_corr, w_mse, w_color)``.  The default equals the
            previously hard-coded weights.

    Returns:
        Nested dict ``results[i][j][edge1][edge2]``.  Each value is a dict
        with keys ``'correlation'``, ``'mse'``, ``'color_diff'``, ``'type1'``,
        ``'type2'``, ``'combined_score'``, ``'norm_correlation'``,
        ``'norm_mse'`` and ``'norm_color_diff'``.  Pairs with no allowed edge
        pairing are omitted; if nothing is allowed, an empty dict is returned.
    """
    w_corr, w_mse, w_color = weights
    comparison_results = {}
    all_results = []  # every result dict, in insertion order

    for i, piece1 in enumerate(pieces):
        for j, piece2 in enumerate(pieces):
            if i == j:
                continue
            for edge1 in range(4):
                for edge2 in range(4):
                    e1 = piece1.edges[edge1]
                    e2 = piece2.edges[edge2]
                    if not e1.can_match_with(e2):
                        continue
                    correlation, mse, color_diff = compare_edges(e1, e2)
                    result = {
                        'correlation': correlation,
                        'mse': mse,
                        'color_diff': color_diff,
                        'type1': e1.edge_type,
                        'type2': e2.edge_type,
                    }
                    comparison_results.setdefault(i, {}).setdefault(j, {}).setdefault(edge1, {})[edge2] = result
                    all_results.append(result)

    # Nothing to normalise; min()/max() of an empty list would raise ValueError.
    if not all_results:
        return comparison_results

    # The raw metrics live on different scales, so rescale each to its observed range.
    min_correlation = min(r['correlation'] for r in all_results)
    max_correlation = max(r['correlation'] for r in all_results)
    min_mse = min(r['mse'] for r in all_results)
    max_mse = max(r['mse'] for r in all_results)
    min_color_diff = min(r['color_diff'] for r in all_results)
    max_color_diff = max(r['color_diff'] for r in all_results)

    for result in all_results:
        norm_correlation = normalize_score(result['correlation'], min_correlation, max_correlation)
        norm_mse = normalize_score(result['mse'], min_mse, max_mse)
        norm_color_diff = normalize_score(result['color_diff'], min_color_diff, max_color_diff)

        result['combined_score'] = (1 - norm_correlation) * w_corr + norm_mse * w_mse + norm_color_diff * w_color
        result['norm_correlation'] = norm_correlation
        result['norm_mse'] = norm_mse
        result['norm_color_diff'] = norm_color_diff

    return comparison_results

def _reverse_and_mirror(shape):
    """Return a copy of ``shape`` with the point order reversed and x mirrored about the centre of its x-range."""
    flipped = shape[::-1].copy()
    center_x = (np.max(flipped[:, 0]) + np.min(flipped[:, 0])) / 2
    flipped[:, 0] = 2 * center_x - flipped[:, 0]
    return flipped


def visualize_match(piece1, piece2, edge1_index, edge2_index, match_info):
    """Plot a candidate edge match for manual inspection.

    The figure has four stacked panels:
    1. Both piece images side by side, with a green line joining the
       midpoints of the two matched edges' corner chords.
    2. Edge 1's shape, raw and smoothed.
    3. Edge 2's shape, reversed and mirrored, raw and smoothed.
    4. The two colour strips (edge 2's strip reversed).

    The title shows ``combined_score`` and the normalised metrics from ``match_info``.
    """
    gap = 50  # horizontal gap (px) between the two piece images
    _, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(5, 10))

    h1, w1 = piece1.image.shape[:2]
    h2, w2 = piece2.image.shape[:2]
    canvas = np.ones((max(h1, h2), w1 + w2 + gap, 3), dtype=np.uint8) * 255
    canvas[:h1, :w1] = piece1.image
    canvas[:h2, w1 + gap:] = piece2.image

    edge1 = piece1.edges[edge1_index]
    edge2 = piece2.edges[edge2_index]

    start1, end1 = piece1.corners[edge1_index], piece1.corners[(edge1_index + 1) % 4]
    start2, end2 = piece2.corners[edge2_index], piece2.corners[(edge2_index + 1) % 4]

    # Cast both midpoints to plain ints for cv2.line (previously only piece 2's were cast).
    offset_x = w1 + gap
    mid1 = (int((start1[0] + end1[0]) // 2), int((start1[1] + end1[1]) // 2))
    mid2 = (int((start2[0] + end2[0]) // 2) + offset_x, int((start2[1] + end2[1]) // 2))
    cv2.line(canvas, mid1, mid2, (0, 255, 0), 5)

    ax1.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
    ax1.set_title("Matching Puzzle Pieces")
    ax1.axis('off')

    smoothed_edge1 = smooth_curve(edge1.shape)
    ax2.plot(edge1.shape[:, 0], edge1.shape[:, 1], 'b-', label='Original')
    ax2.plot(smoothed_edge1[:, 0], smoothed_edge1[:, 1], 'r--', label='Smoothed')
    ax2.set_title(f"Piece {piece1.id+1}, Edge {edge1_index+1}")
    ax2.set_aspect('equal')
    ax2.grid(True)
    ax2.legend()

    edge2_flipped = _reverse_and_mirror(edge2.shape)
    smoothed_edge2_flipped = _reverse_and_mirror(smooth_curve(edge2.shape))
    ax3.plot(edge2_flipped[:, 0], edge2_flipped[:, 1], 'b-', label='Original')
    ax3.plot(smoothed_edge2_flipped[:, 0], smoothed_edge2_flipped[:, 1], 'r--', label='Smoothed')
    ax3.set_title(f"Piece {piece2.id+1}, Edge {edge2_index+1} (Reversed & Flipped)")
    ax3.set_aspect('equal')
    ax3.grid(True)
    ax3.legend()

    num_samples = len(edge1.colors)
    scale_factor = 3  # pixel width of each colour swatch
    color_vis = np.zeros((100, num_samples * scale_factor, 3), dtype=np.uint8)
    for i, (color1, color2) in enumerate(zip(edge1.colors, edge2.colors[::-1])):
        color_vis[0:50, i*scale_factor:(i+1)*scale_factor] = color1
        color_vis[50:100, i*scale_factor:(i+1)*scale_factor] = color2

    ax4.imshow(cv2.cvtColor(color_vis, cv2.COLOR_BGR2RGB))
    ax4.set_title("Color Samples (Top: Piece 1, Bottom: Piece 2 Reversed)")
    ax4.axis('off')

    plt.suptitle(f"Match: Pieces {piece1.id+1} and {piece2.id+1}\n" +
                 f"Combined Score: {match_info['combined_score']:.4f}\n" +
                 f"Norm Correlation: {match_info['norm_correlation']:.4f}, " +
                 f"Norm MSE: {match_info['norm_mse']:.4f}, " +
                 f"Norm Color Diff: {match_info['norm_color_diff']:.4f}")

    plt.tight_layout()
    plt.show()

# Debug helper used while fine-tuning the match-score weights.
def visualize_top_matches(pieces, top_matches):
    """Plot (via ``visualize_match``) and print a summary of each match.

    Each match is a dict with keys ``'piece1'``, ``'piece2'``, ``'edge1'``,
    ``'edge2'``, ``'type1'``, ``'type2'``, ``'combined_score'``,
    ``'norm_correlation'``, ``'norm_mse'`` and ``'norm_color_diff'``.
    Piece and edge numbers are printed 1-based.
    """
    for rank, match in enumerate(top_matches, start=1):
        first = pieces[match['piece1']]
        second = pieces[match['piece2']]

        visualize_match(first, second, match['edge1'], match['edge2'], match)

        print(f"Match {rank}:")
        print(f"Piece {match['piece1']+1} (Type: {first.piece_type}) Edge {match['edge1']+1} ({match['type1']}) and")
        print(f"Piece {match['piece2']+1} (Type: {second.piece_type}) Edge {match['edge2']+1} ({match['type2']})")
        print(f"Combined Score: {match['combined_score']:.4f}")
        print(f"Normalized Correlation: {match['norm_correlation']:.4f}")
        print(f"Normalized MSE: {match['norm_mse']:.4f}")
        print(f"Normalized Color Difference: {match['norm_color_diff']:.4f}")
        print()



def find_best_starting_corner(pieces, comparison_results):
    """Pick the lowest-scoring corner-to-border match as the assembly seed.

    Only entries of ``comparison_results[i][j][edge1][edge2]`` where piece
    ``i`` is a corner and piece ``j`` is a border piece are considered.  The
    one with the lowest ``combined_score`` wins (the first one on ties).

    Returns:
        ``(corner_piece, (border_piece, corner_edge, border_edge))``, or
        ``(None, None)`` after printing a message if there is no such entry.
    """
    best_score = float('inf')
    best = None  # (corner_id, border_id, corner_edge, border_edge)

    for corner_id, by_piece in comparison_results.items():
        if pieces[corner_id].piece_type != 'corner':
            continue
        for border_id, by_edge1 in by_piece.items():
            if pieces[border_id].piece_type != 'border':
                continue
            for corner_edge, by_edge2 in by_edge1.items():
                for border_edge, data in by_edge2.items():
                    if data['combined_score'] < best_score:
                        best_score = data['combined_score']
                        best = (corner_id, border_id, corner_edge, border_edge)

    if best is None:
        print("No suitable starting corner found.")
        return None, None

    corner_id, border_id, corner_edge, border_edge = best
    return pieces[corner_id], (pieces[border_id], corner_edge, border_edge)

# Index of the opposite side of a four-sided piece (used when walking along the border).
def get_opposite_edge(edge):
    """Return the edge index two steps away from ``edge``, modulo 4."""
    opposite = (2 + edge) % 4
    return opposite

    
def assemble_edge(start_piece, start_edge, pieces, comparison_results, all_border_pieces, max_pieces=4):
    """Assemble the first side of the frame, starting from ``start_piece``.

    Each step asks ``find_next_edge_piece`` for a piece matching the edge
    opposite the one the current piece was joined by.  The new piece is
    connected and its match is plotted.  Assembly stops after
    ``max_pieces`` pieces, on reaching a corner piece, or when no unused
    piece is found.

    Side effects:
    * Adds every placed piece id to the module-level ``used_pieces`` set.
    * Records connections on the pieces.
    * Shows one plot per match and prints the assembled sequence.

    Args:
        start_piece: Piece to start from.
        start_edge: Edge index of ``start_piece``; the search starts on the
            edge opposite it.
        pieces, comparison_results: Passed through to ``find_next_edge_piece``.
        all_border_pieces: Set of candidate piece ids (presumably the border
            and corner pieces -- confirm against the caller).
        max_pieces: Maximum number of pieces on this side.

    Returns:
        List of the assembled pieces, in order.
    """
    global used_pieces
    assembled_pieces = [start_piece]
    used_pieces.add(start_piece.id)
    current_piece = start_piece
    current_edge = start_edge

    unused_border_pieces = all_border_pieces - used_pieces

    for _ in range(max_pieces - 1):
        search_edge = get_opposite_edge(current_edge)
        next_piece, next_edge, match = find_next_edge_piece(current_piece, search_edge, pieces, comparison_results, unused_border_pieces)
        if not next_piece or next_piece.id in used_pieces:
            break

        assembled_pieces.append(next_piece)
        used_pieces.add(next_piece.id)
        unused_border_pieces.discard(next_piece.id)
        current_piece.connect(search_edge, next_piece, next_edge)
        visualize_match(current_piece, next_piece, search_edge, next_edge, match)

        # A corner ends this side of the frame.
        if next_piece.piece_type == 'corner':
            break
        current_piece = next_piece
        current_edge = next_edge

    print("\nAssembled edge pieces:")
    for piece in assembled_pieces:
        print(f"Piece {piece.id + 1} ({piece.piece_type})", end=" - ")
    print("End")

    return assembled_pieces


def _edge_joined_to(piece, neighbour):
    """Return the edge index of ``piece`` recorded as joined to ``neighbour``, or None.

    Checks ``piece.connections`` first and then ``neighbour.connections``, in
    case ``connect`` only recorded the join on the side it was called on.
    """
    for edge_index, connection in piece.connections.items():
        if connection and connection[0] is neighbour:
            return edge_index
    for connection in neighbour.connections.values():
        if connection and connection[0] is piece and isinstance(connection[1], (int, np.integer)):
            return int(connection[1])
    return None


def assemble_next_edge(assembled_edge, pieces, comparison_results, all_border_pieces, max_pieces=4, first_corner=None):
    """Assemble the next border edge, starting from the corner that ended ``assembled_edge``.

    The corner's edge already joined to the previous piece is found first.
    The join recorded by ``connect`` is preferred; if none is recorded, the
    first type-compatible edge is used instead. ``determine_start_edge`` then
    picks the corner's other free edge. From there, unused border pieces are
    chained greedily. The last slot (``max_pieces - 1``) only accepts corner
    pieces.

    When no border pieces remain and the edge did not end in a corner, the
    last piece is connected back to ``first_corner`` to close the border.

    Side effects: adds placed piece ids to the module-level ``used_pieces``
    set and records joins with ``Piece.connect``.

    Args:
        assembled_edge: pieces of the previous edge, in placement order.
        pieces: all pieces, indexed by piece id.
        comparison_results: nested edge-comparison results.
        all_border_pieces: set of ids of border and corner pieces.
        max_pieces: maximum number of pieces in this edge, including the start corner.
        first_corner: the corner the whole border started from, used to close the loop.

    Returns:
        The pieces of the new edge in placement order, starting with the corner.

    Raises:
        ValueError: if ``assembled_edge`` contains no corner piece.
    """
    global used_pieces
    # find last used corner piece to start edge
    corner_piece = next((piece for piece in reversed(assembled_edge) if piece.piece_type == 'corner'), None)
    if corner_piece is None:
        raise ValueError("assembled_edge contains no corner piece to continue from")

    # find which corner edge is already joined to the previous piece
    previous_match_edge = None
    if len(assembled_edge) > 1:
        previous_piece = assembled_edge[-2]
        previous_match_edge = _edge_joined_to(corner_piece, previous_piece)
        if previous_match_edge is None:
            # No recorded join: fall back to the first type-compatible edge pair.
            previous_match_edge = next(
                (i for i, corner_edge in enumerate(corner_piece.edges)
                 if any(corner_edge.can_match_with(prev_edge) and prev_edge.can_match_with(corner_edge)
                        for prev_edge in previous_piece.edges)),
                None,
            )

    # start edge will be the other unused valid edge.
    start_edge = determine_start_edge(corner_piece, previous_match_edge)

    new_edge = [corner_piece]
    current_piece = corner_piece
    current_edge = start_edge

    unused_border_pieces = all_border_pieces - used_pieces

    for i in range(max_pieces - 1):
        # the last match on an edge must be a corner piece
        last_slot = i == max_pieces - 2
        next_piece, next_edge, match = find_next_edge_piece(current_piece, current_edge, pieces, comparison_results, unused_border_pieces, only_corners=last_slot)
        if not next_piece:
            break

        new_edge.append(next_piece)
        used_pieces.add(next_piece.id)
        unused_border_pieces.discard(next_piece.id)
        current_piece.connect(current_edge, next_piece, next_edge)

        current_piece = next_piece
        current_edge = get_opposite_edge(next_edge)
        if next_piece.piece_type == 'corner':
            break

    # close the border loop back to the very first corner
    if not unused_border_pieces and first_corner and new_edge[-1].piece_type != 'corner':
        last_piece = new_edge[-1]

        # unconnected left/right edge of the last piece
        last_piece_edge = next(
            (i for i, edge in enumerate(last_piece.edges)
             if edge.position in ['left', 'right'] and last_piece.connections[i] is None),
            None,
        )
        # unconnected bottom/left edge of the first corner
        first_piece_edge = next(
            (i for i, edge in enumerate(first_corner.edges)
             if edge.position in ['bottom', 'left'] and first_corner.connections[i] is None),
            None,
        )

        if last_piece_edge is None or first_piece_edge is None:
            print("Could not find free edges to close the border loop; leaving it open.")
        else:
            last_piece.connect(last_piece_edge, first_corner, first_piece_edge)

    print("\nAssembled new edge pieces:")
    for piece in new_edge:
        print(f"Piece {piece.id + 1} ({piece.piece_type})", end=" - ")
    print("End")

    return new_edge

def find_next_edge_piece(current_piece, current_edge, pieces, comparison_results, unused_border_pieces, only_corners=False):
    """Return the best-scoring unused border piece to attach at ``current_edge``.

    Scores are looked up in both storage orders of ``comparison_results``:
    ``[current][other][current_edge][other_edge]`` and
    ``[other][current][other_edge][current_edge]``. The candidate with the
    lowest ``'combined_score'`` wins; on ties, the first candidate found is kept.

    Args:
        current_piece: piece being extended.
        current_edge: edge index of ``current_piece`` to match.
        pieces: all pieces, indexed by piece id.
        comparison_results: nested edge-comparison results.
        unused_border_pieces: ids of border/corner pieces still available.
        only_corners: if true, only corner pieces are considered (used for
            the final piece of an edge).

    Returns:
        ``(piece, edge_index, match_data)`` for the best candidate, or
        ``(None, None, None)`` when there is none.
    """
    candidates = []

    for other_id in range(len(pieces)):
        if other_id not in unused_border_pieces:
            continue
        other = pieces[other_id]
        if only_corners and other.piece_type != 'corner':
            continue

        # stored as [current][other][current_edge][other_edge]
        forward = comparison_results.get(current_piece.id, {})
        if other_id in forward:
            for other_edge, match_data in forward[other_id].get(current_edge, {}).items():
                candidates.append((other, other_edge, match_data))

        # stored as [other][current][other_edge][current_edge]
        backward = comparison_results.get(other_id, {})
        if current_piece.id in backward:
            for other_edge, by_current_edge in backward[current_piece.id].items():
                if current_edge in by_current_edge:
                    candidates.append((other, other_edge, by_current_edge[current_edge]))

    if not candidates:
        return None, None, None
    return min(candidates, key=lambda candidate: candidate[2]['combined_score'])

def determine_start_edge(corner_piece, previous_match_edge):
    """Pick the corner's free non-straight edge, given the one already matched.

    A corner piece has exactly two straight edges; the edges facing them are
    its two joinable edges. If ``previous_match_edge`` is one of those, the
    other is returned; otherwise the result is ``None``.

    Raises:
        ValueError: if the corner does not have exactly two straight edges.
    """
    first_straight, second_straight = [
        index for index, edge in enumerate(corner_piece.edges) if edge.edge_type == 'straight'
    ]
    facing_first = (first_straight + 2) % 4
    facing_second = (second_straight + 2) % 4

    # each joinable edge maps to the other one
    other_free_edge = {facing_second: facing_first, facing_first: facing_second}
    return other_free_edge.get(previous_match_edge)


def determine_edge_lengths(first_edge_length, puzzle_size=(3, 4)):
    """Return the expected piece counts of the remaining three border edges.

    Going round a rectangular border, edge lengths alternate between the two
    side lengths. Lengths count pieces including both corners. So if the
    first edge has one side length, the next three are ``[other, same, other]``.

    Args:
        first_edge_length: number of pieces found on the first edge.
        puzzle_size: ``(rows, cols)`` of the puzzle. Defaults to the 3x4
            layout this pipeline was written for, which gives the original
            behaviour (3 -> [4, 3, 4], 4 -> [3, 4, 3]).

    Returns:
        A list of three lengths, or ``None`` if ``first_edge_length`` matches
        neither side of ``puzzle_size``.
    """
    rows, cols = puzzle_size
    if first_edge_length == rows:
        return [cols, rows, cols]
    if first_edge_length == cols:
        return [rows, cols, rows]
    return None

def create_puzzle_grid(border_pieces, puzzle_size):
    """Place the ordered border pieces clockwise round an empty object grid.

    The walk goes along the top row left to right, down the right column
    (excluding corners), along the bottom row right to left, and up the left
    column (excluding corners). ``border_pieces`` is consumed in that order;
    extra pieces are ignored and too few raise ``IndexError``. The grid is
    printed before being returned.

    Args:
        border_pieces: border/corner pieces in clockwise order.
        puzzle_size: ``(rows, cols)`` of the grid.

    Returns:
        A numpy object array of shape ``puzzle_size`` with ``None`` in the
        interior cells.
    """
    grid = np.full(puzzle_size, None, dtype=object)
    rows, cols = puzzle_size

    perimeter = (
        [(0, c) for c in range(cols)]                   # top row
        + [(r, -1) for r in range(1, rows - 1)]         # right column
        + [(-1, c) for c in range(cols - 1, -1, -1)]    # bottom row
        + [(r, 0) for r in range(rows - 2, 0, -1)]      # left column
    )
    for slot, (r, c) in enumerate(perimeter):
        grid[r, c] = border_pieces[slot]

    print("Grid after placing border pieces:")
    print_grid(grid)
    return grid


def print_grid(grid):
    """Print a 2-D grid of cells, tab-separated, followed by a blank line.

    Cells print as ``None``, ``Int:<value>`` for plain ints, ``P<id+1>`` for
    ``Piece`` objects and ``Unknown:<type>`` for anything else.
    """
    def label(cell):
        if cell is None:
            return "None"
        if isinstance(cell, int):
            return f"Int:{cell}"
        if isinstance(cell, Piece):
            return f"P{cell.id+1}"
        return f"Unknown:{type(cell)}"

    n_rows, n_cols = grid.shape[0], grid.shape[1]
    for r in range(n_rows):
        for c in range(n_cols):
            print(label(grid[r, c]), end="\t")
        print()
    print()



def _lookup_combined_score(comparison_results, id_a, id_b, edge_a, edge_b):
    """Return ``comparison_results[id_a][id_b][edge_a][edge_b]['combined_score']``, or None if a level is missing."""
    by_edge = comparison_results.get(id_a, {}).get(id_b, {}).get(edge_a, {})
    if edge_b not in by_edge:
        return None
    return by_edge[edge_b]['combined_score']


def get_edge_match_score(piece1, piece2, edge1_pos, edge2_pos, comparison_results):
    """Return the stored match score between two named edges of two pieces.

    Edges are located by their ``position`` label. Each index is then shifted
    by the piece's ``current_orientation`` (mod 4) before the lookup. The
    score is looked up as ``[piece1][piece2][edge1][edge2]`` and, if that
    entry is missing, as ``[piece2][piece1][edge2][edge1]``. The score is
    printed before returning.

    Returns:
        The ``'combined_score'`` found, or ``float('inf')`` if either piece is
        ``None``, an edge position is not found, or no score is stored in
        either order.
    """
    if piece1 is None or piece2 is None:
        return float('inf')

    edge1_index = next((i for i, edge in enumerate(piece1.edges) if edge.position == edge1_pos), None)
    edge2_index = next((i for i, edge in enumerate(piece2.edges) if edge.position == edge2_pos), None)
    if edge1_index is None or edge2_index is None:
        return float('inf')

    # presumably maps the rotated edge back to its stored (unrotated) index
    edge1_index = (edge1_index - piece1.current_orientation) % 4
    edge2_index = (edge2_index - piece2.current_orientation) % 4

    # check in both directions; the reverse order is tried whenever the
    # forward entry is missing, not only when the whole forward pair is absent
    score = _lookup_combined_score(comparison_results, piece1.id, piece2.id, edge1_index, edge2_index)
    if score is None:
        score = _lookup_combined_score(comparison_results, piece2.id, piece1.id, edge2_index, edge1_index)
    if score is None:
        score = float('inf')

    print(f"Score for Piece {piece1.id+1} (Edge {edge1_pos}) and Piece {piece2.id+1} (Edge {edge2_pos}): {score}")
    return score



def fill_puzzle_interior(grid, unused_pieces, comparison_results):
    """Greedily place interior pieces into the empty cells of ``grid``.

    Cells are visited row by row, skipping the outer ring. For each empty
    cell, every remaining piece is tried in all four orientations. Each try is
    scored with ``get_edge_match_score`` against the left, above and below
    neighbours; the right neighbour is not scored. The lowest combined score
    wins. The winner is written into the grid, removed from ``unused_pieces``
    (mutated in place), has ``current_orientation`` set, and its joins are
    recorded with ``connect`` on the new piece and on its neighbours.

    Finally, if at least two non-straight edges are still unconnected, the
    last two found are connected to each other.

    Args:
        grid: 2-D object array (e.g. from ``create_puzzle_grid``) with the
            border already filled.
        unused_pieces: list of interior pieces still to place; modified in place.
        comparison_results: nested edge-comparison results.

    Returns:
        The same ``grid`` object, filled in place.
    """
    rows, cols = grid.shape
    # debugging
    print("\nStarting interior filling process")
    print("Initial grid state:")
    print_grid(grid)
    print(f"Unused pieces: {[p.id + 1 for p in unused_pieces]}")
    # True until the first empty cell has been processed. For that cell the
    # left neighbour is scored and joined on its 'bottom' edge; later cells
    # use the left neighbour's first unconnected edge instead.
    is_first_call = True
    for row in range(1, rows-1):  # Skip first and last row
        for col in range(1, cols-1):  # Skip first and last column
            if grid[row, col] is None:
                print(f"\nAttempting to fill position ({row}, {col})")
                # starting with the top-left-most empty grid position
                left_piece = grid[row, col-1]
                above_piece = grid[row-1, col]
                # NOTE(review): assumes the row below is already filled (true for a
                # 3-row puzzle, where it is the border). Otherwise below_piece is None:
                # its score becomes inf and the `.edges` lookups further down fail.
                below_piece = grid[row+1, col]

                best_match = None
                best_score = float('inf')
                best_orientation = 0

                # check edges and then rotate interior piece
                for piece in unused_pieces:
                    for orientation in range(4):
                        # indices of the candidate's top/left/bottom edges after rotating by `orientation`
                        top_edge_index = (0 - orientation) % 4
                        left_edge_index = (1 - orientation) % 4
                        bottom_edge_index = (2 - orientation) % 4

                        if is_first_call:
                            above_score = get_edge_match_score(above_piece, piece, 'bottom', piece.edges[top_edge_index].position, comparison_results)
                            left_score = get_edge_match_score(left_piece, piece, 'bottom', piece.edges[left_edge_index].position, comparison_results)
                            below_score = get_edge_match_score(below_piece, piece, 'bottom', piece.edges[bottom_edge_index].position, comparison_results)
                        else:
                            print("Checking for unconnected edge")
                            # first unconnected edge of the left neighbour ('bottom' if all are connected);
                            # the same for every candidate, since connections don't change inside this loop
                            left_piece_edge = next((i for i, conn in left_piece.connections.items() if conn is None), None)
                            left_piece_edge_position = left_piece.edges[left_piece_edge].position if left_piece_edge is not None else 'bottom'

                            above_score = get_edge_match_score(above_piece, piece, 'bottom', piece.edges[top_edge_index].position, comparison_results)
                            left_score = get_edge_match_score(left_piece, piece, left_piece_edge_position, piece.edges[left_edge_index].position, comparison_results)
                            below_score = get_edge_match_score(below_piece, piece, 'bottom', piece.edges[bottom_edge_index].position, comparison_results)

                        combined_score = left_score + above_score + below_score
                        print(f"  Piece {piece.id+1}, Orientation {orientation}: Left Score = {left_score}, Above Score = {above_score}, Below Score = {below_score}, Combined = {combined_score}")

                        # strict `<`: on ties the first piece/orientation tried is kept
                        if combined_score < best_score:
                            best_score = combined_score
                            best_match = piece
                            best_orientation = orientation
                
                if best_match:
                    grid[row, col] = best_match
                    unused_pieces.remove(best_match)
                    print(f"Placed Piece {best_match.id + 1} at position ({row}, {col}) with orientation {best_orientation}")
                    
                    # calculate edge indices based on best orientation
                    top_edge_index = (0 - best_orientation) % 4
                    left_edge_index = (1 - best_orientation) % 4
                    bottom_edge_index = (2 - best_orientation) % 4

                    # interior pieces always match the neighbouring pieces' 'bottom' edge
                    above_bottom_edge_index = next(i for i, edge in enumerate(above_piece.edges) if edge.position == 'bottom')
                    left_bottom_edge_index = next(i for i, edge in enumerate(left_piece.edges) if edge.position == 'bottom')
                    below_bottom_edge_index = next(i for i, edge in enumerate(below_piece.edges) if edge.position == 'bottom')

                    
                    print(f"Piece {best_match.id + 1} connections before: {best_match.connections}")
                    print(f"Above piece {above_piece.id + 1} connections before: {above_piece.connections}")
                    print(f"Left piece {left_piece.id + 1} connections before: {left_piece.connections}")
                    print(f"Below piece {below_piece.id + 1} connections before: {below_piece.connections}")

                    # record the joins on the new piece's side
                    # NOTE(review): after the first cell, the left join is only recorded on left_piece's side (below)
                    best_match.connect(top_edge_index, above_piece, above_bottom_edge_index)
                    print(f"top edge index {top_edge_index}")
                    if is_first_call:
                        best_match.connect(left_edge_index, left_piece, left_bottom_edge_index)
                    best_match.connect(bottom_edge_index, below_piece, below_bottom_edge_index)
                    
                    print(f"Connected Piece {best_match.id + 1} (top) with Piece {above_piece.id + 1} (bottom)")
                    print(f"Connected Piece {best_match.id + 1} (left) with Piece {left_piece.id + 1} ({'bottom' if is_first_call else left_piece_edge_position})")
                    print(f"Connected Piece {best_match.id + 1} (bottom) with Piece {below_piece.id + 1} (bottom)")
                    
                    # looked up again after the connect() calls above
                    # (presumably unchanged unless connect() alters edge positions — TODO confirm)
                    above_bottom_edge_index = next(i for i, edge in enumerate(above_piece.edges) if edge.position == 'bottom')
                    left_bottom_edge_index = next(i for i, edge in enumerate(left_piece.edges) if edge.position == 'bottom')
                    below_bottom_edge_index = next(i for i, edge in enumerate(below_piece.edges) if edge.position == 'bottom')

                    print(f"above edge {above_bottom_edge_index}")
                    print(f"left edge {left_bottom_edge_index}")
                    print(f"below edge {below_bottom_edge_index}")

                    # record the joins on the neighbours' side; for the first interior
                    # piece the left neighbour is joined on its 'bottom' edge
                    above_piece.connect(above_bottom_edge_index, best_match, top_edge_index)
                    if is_first_call:
                        left_piece.connect(left_bottom_edge_index, best_match, left_edge_index)
                        print(left_bottom_edge_index)
                    else:
                        # NOTE(review): left_piece_edge is None if every edge of left_piece was
                        # already connected, and connect() is then called with a None edge index
                        left_piece.connect(left_piece_edge, best_match, left_edge_index)
                        print(f"leftttttt{left_piece_edge}")
                    below_piece.connect(below_bottom_edge_index, best_match, bottom_edge_index)
                    best_match.current_orientation = best_orientation

                    print(f"Piece {best_match.id + 1} connections after: {best_match.connections}")
                    print(f"Above piece {above_piece.id + 1} connections after: {above_piece.connections}")
                    print(f"Left piece {left_piece.id + 1} connections after: {left_piece.connections}")
                    print(f"Below piece {below_piece.id + 1} connections after: {below_piece.connections}")

                else:
                    print(f"No suitable piece found for position ({row}, {col}). Best score was: {best_score}")
                is_first_call = False

    # once the interior is filled, connect final unconnected edges
    # NOTE(review): this scans rows 1..rows-1 and cols 1..cols-1, so the bottom row and
    # right column (border pieces) are included but the top row and left column are not
    unconnected_edges = []
    for row in range(1, rows):
        for col in range(1, cols):
            piece = grid[row, col]
            print(row, col)
            if piece:
                for edge_index, edge in enumerate(piece.edges):
                    if edge.edge_type != 'straight' and piece.connections[edge_index] is None:
                        unconnected_edges.append((piece, edge_index))
                        print(unconnected_edges)
    
    # join the last two unconnected non-straight edges found; connect() is called on piece1 only
    if len(unconnected_edges) >= 2:
        piece1, edge1 = unconnected_edges[-2]
        piece2, edge2 = unconnected_edges[-1]
        piece1.connect(edge1, piece2, edge2)
            
    
    print("\nGrid after filling interior:")
    print_grid(grid)
    print(f"Remaining unused pieces: {[p.id + 1 for p in unused_pieces]}")
    
    return grid


def visualize_all_connections(original_image_path, all_pieces):
    """Plot every piece's corners, edge labels, id and recorded connections.

    Drawing happens on an RGB copy of the image. For each piece:

    - each corner gets a filled circle;
    - each edge's ``position`` label is written at the edge midpoint;
    - the piece id (1-based) is written at the corner centroid;
    - each recorded connection is drawn as two lines joining the edge's
      corners to the connected edge's corners.

    The result is shown with matplotlib.

    Args:
        original_image_path: path of the uncropped puzzle photo.
        all_pieces: pieces to draw; each must provide
            ``get_absolute_corner_positions()``, ``edges`` and ``connections``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``original_image_path``.
    """
    original_image = cv2.imread(original_image_path)
    if original_image is None:
        # cv2.imread signals failure by returning None rather than raising
        raise FileNotFoundError(f"Could not read image: {original_image_path}")
    result_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    for piece in all_pieces:
        corners = piece.get_absolute_corner_positions()  # corner positions in uncropped image
        for i, corner in enumerate(corners):
            cv2.circle(result_image, corner, 5, (255, 0, 0), -1)

            # label edge i, which runs from corner i to corner i+1
            next_corner = corners[(i + 1) % 4]
            edge_center = ((corner[0] + next_corner[0]) // 2, (corner[1] + next_corner[1]) // 2)
            edge_position = piece.edges[i].position
            cv2.putText(result_image, f"{edge_position}", edge_center, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            if piece.connections[i]:
                connected_piece, connected_edge = piece.connections[i]
                connected_corners = connected_piece.get_absolute_corner_positions()

                # connections may store either an edge index or a position label
                if isinstance(connected_edge, str):
                    connected_edge_index = next(idx for idx, edge in enumerate(connected_piece.edges) if edge.position == connected_edge)
                else:
                    connected_edge_index = connected_edge

                # draw connection: joined edges run in opposite directions, so cross the corners
                cv2.line(result_image, next_corner, connected_corners[connected_edge_index], (0, 255, 255), 2)
                cv2.line(result_image, corner, connected_corners[(connected_edge_index + 1) % 4], (0, 255, 255), 2)

        center = tuple(map(int, np.mean(corners, axis=0)))
        cv2.putText(result_image, f"{piece.id + 1}", center, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    plt.figure(figsize=(20, 20))
    plt.imshow(result_image)
    plt.axis('off')
    plt.title("All Piece Connections with Edge Labels")
    plt.show()


def align_pieces(border_pieces, updated_corners=None):
    """Compute the rigid move that places the second piece against the first.

    ``border_pieces`` is a pair ``(first_piece, second_piece)`` with a
    recorded connection on either side.

    - Rotation: the second piece's corners are rotated about their centroid
      so its joined edge runs anti-parallel to the first piece's joined edge.
    - Translation: the corners are then shifted so that corner
      ``connection_edge2`` of the second piece lands on the end corner of the
      first piece's joined edge.

    Side effects:

    - sets ``first_piece.center`` and ``second_piece.center``;
    - shifts ``second_piece.original_position`` by the translation;
    - stores both pieces' corners in ``updated_corners``;
    - shows a plot via ``visualize_aligned_pieces``.

    Args:
        border_pieces: two-element sequence ``(first_piece, second_piece)``.
        updated_corners: dict mapping piece id to corner positions already
            moved by earlier alignments; a new dict is used if ``None``.

    Returns:
        ``(updated_corners, original_center, new_center, rotation)``:

        - ``original_center`` and ``new_center`` are the second piece's
          integer centroid before and after the move;
        - ``rotation`` is ``360 - (angle_degrees % 360)``.

        Returns ``None`` (after printing a message) when no connection
        between the two pieces is recorded.
    """
    first_piece, second_piece = border_pieces

    if updated_corners is None:
        updated_corners = {}

    # get corners from uncropped image (or from an earlier alignment)
    first_piece_corners = updated_corners.get(first_piece.id, first_piece.get_absolute_corner_positions())
    second_piece_corners = updated_corners.get(second_piece.id, second_piece.get_absolute_corner_positions())

    original_center = tuple(map(int, np.mean(second_piece_corners, axis=0)))

    first_piece.center = tuple(map(int, np.mean(first_piece_corners, axis=0)))

    # debug
    print(f"Aligning piece {first_piece.id} with piece {second_piece.id}")
    print(f"First piece connections: {first_piece.connections}")
    print(f"Second piece connections: {second_piece.connections}")

    # find the connection, looking on the first piece's side and then the second's
    connection_edge1 = None
    connection_edge2 = None
    for edge, connection in first_piece.connections.items():
        if connection and connection[0] == second_piece:
            connection_edge1 = edge
            connection_edge2 = connection[1]
            break
    if connection_edge1 is None:
        for edge, connection in second_piece.connections.items():
            if connection and connection[0] == first_piece:
                connection_edge1 = connection[1]
                connection_edge2 = edge
                break
    if connection_edge1 is None or connection_edge2 is None:
        print(f"No connection recorded between piece {first_piece.id} and piece {second_piece.id}; skipping alignment")
        return None

    print(f"Using connection: Piece {first_piece.id} Edge {connection_edge1} - Piece {second_piece.id} Edge {connection_edge2}")

    # connections may store a position label instead of an edge index
    if isinstance(connection_edge1, str):
        connection_edge1 = first_piece.edges.index(next(edge for edge in first_piece.edges if edge.position == connection_edge1))

    if isinstance(connection_edge2, str):
        connection_edge2 = second_piece.edges.index(next(edge for edge in second_piece.edges if edge.position == connection_edge2))

    # Get the connecting corner positions (edge i runs from corner i to corner i+1)
    corner1_start = np.array(first_piece_corners[connection_edge1])
    corner1_end = np.array(first_piece_corners[(connection_edge1 + 1) % 4])
    corner2_start = np.array(second_piece_corners[connection_edge2])
    corner2_end = np.array(second_piece_corners[(connection_edge2 + 1) % 4])

    # rotate edge 2 onto edge 1's direction, then add pi so the joined edges run opposite ways
    vector1 = corner1_end - corner1_start
    vector2 = corner2_end - corner2_start
    angle = np.arctan2(vector1[1], vector1[0]) - np.arctan2(vector2[1], vector2[0])
    angle += np.pi

    # rotate the second piece about its centroid
    rotation_matrix = np.array([
        [np.cos(angle), -np.sin(angle)],
        [np.sin(angle), np.cos(angle)]
    ])
    center = np.mean(second_piece_corners, axis=0)
    rotated_corners = [tuple(map(int, np.dot(rotation_matrix, np.array(corner) - center) + center)) for corner in second_piece_corners]

    # translate so the start of edge 2 meets the end of edge 1
    corner1 = corner1_end
    corner2 = np.array(rotated_corners[connection_edge2])
    translation = corner1 - corner2

    second_piece_corners = [tuple(map(int, np.array(corner) + translation)) for corner in rotated_corners]
    new_center = tuple(map(int, np.mean(second_piece_corners, axis=0)))

    # update the second piece
    second_piece.center = new_center
    second_piece.original_position = tuple(np.array(second_piece.original_position[:2]) + translation) + second_piece.original_position[2:]

    # store updated corner positions
    updated_corners[first_piece.id] = first_piece_corners
    updated_corners[second_piece.id] = second_piece_corners

    print(f"Updated first piece center: {first_piece.center}")
    print(f"Updated second piece corners: {second_piece_corners}")
    print(f"Updated second piece center: {second_piece.center}")

    visualize_aligned_pieces(first_piece, second_piece, connection_edge1, connection_edge2, second_piece_corners, original_center)
    rotation_degrees = np.degrees(angle) % 360

    return updated_corners, original_center, new_center, (360 -rotation_degrees)

def visualize_aligned_pieces(piece1, piece2, edge1, edge2, corners, original_center,
                             image_path=r'C:\Users\steph\OneDrive\Masters\cars_final.jpg'):
    """Plot one alignment step on the puzzle photo.

    Drawing happens on an RGB copy of the image:

    - piece1's corners are labelled ``1-i``;
    - piece2's original corners are marked in blue and its moved ``corners``
      in red, labelled ``2-i``, with a green line from each old position to
      its new one;
    - piece2's ``center`` is drawn white and ``original_center`` black;
    - piece2's joined edge ``edge2`` is drawn in blue on the moved corners.

    Args:
        piece1: the fixed piece.
        piece2: the moved piece.
        edge1: unused; kept for call compatibility.
        edge2: index of piece2's joined edge.
        corners: piece2's corners after the move.
        original_center: piece2's centroid before the move.
        image_path: photo to draw on; defaults to the original hard-coded path.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    original_image = cv2.imread(image_path)
    if original_image is None:
        # cv2.imread signals failure by returning None rather than raising
        raise FileNotFoundError(f"Could not read image: {image_path}")
    result_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    corners1 = piece1.get_absolute_corner_positions()
    corners2 = corners
    original_corners2 = piece2.get_absolute_corner_positions()

    print(corners1)
    print(corners2)

    for i, corner in enumerate(corners1):
        cv2.putText(result_image, f"1-{i}", (corner[0]+5, corner[1]+5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    for i, (old_corner, new_corner) in enumerate(zip(original_corners2, corners2)):
        cv2.circle(result_image, old_corner, 5, (0, 0, 255), -1)
        cv2.circle(result_image, new_corner, 5, (255, 0, 0), -1)
        cv2.putText(result_image, f"2-{i}", (new_corner[0]+5, new_corner[1]+5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        cv2.line(result_image, old_corner, new_corner, (0, 255, 0), 2)

    cv2.circle(result_image, piece2.center, 8, (255, 255, 255), -1)
    cv2.circle(result_image, original_center, 8, (0, 0, 0), -1)

    cv2.line(result_image, corners2[edge2], corners2[(edge2 + 1) % 4], (0, 0, 255), 2)

    plt.figure(figsize=(20, 20))
    plt.imshow(result_image)
    plt.axis('off')
    plt.title("Aligned Pieces with Rotation and Translation")
    plt.show()



def automatic_alignment(edges):
    """Align every consecutive pair of pieces and collect the resulting moves.

    The pieces from ``edges`` are flattened, ``None`` entries are dropped and
    duplicates removed (first occurrence kept). ``align_pieces`` is then run
    on each neighbouring pair in that order. Failed or malformed alignments
    are reported and skipped. A summary of all recorded moves is printed.

    Args:
        edges: iterable of piece sequences.

    Returns:
        ``(updated_corners, alignment_data)``. Each ``alignment_data`` entry
        is a dict with keys ``'id'`` (1-based), ``'piece_id'`` (1-based),
        ``'original_center'``, ``'new_center'`` and ``'rotation'``.
    """
    updated_corners = {}
    alignment_data = []

    ordered_pieces = list(dict.fromkeys(
        piece for edge in edges for piece in edge if piece is not None
    ))

    for current_piece, next_piece in zip(ordered_pieces, ordered_pieces[1:]):
        try:
            alignment_result = align_pieces([current_piece, next_piece], updated_corners)

            if alignment_result is None or len(alignment_result) != 4:
                print(f"Warning: Alignment failed for pieces {current_piece.id + 1} and {next_piece.id + 1}")
                print(f"Alignment result: {alignment_result}")
                continue

            updated_corners, original_center, new_center, rotation = alignment_result

            complete = original_center is not None and new_center is not None and rotation is not None
            if complete:
                alignment_data.append({
                    'id': len(alignment_data) + 1,
                    'piece_id': next_piece.id + 1,
                    'original_center': original_center,
                    'new_center': new_center,
                    'rotation': rotation
                })

        except Exception as e:
            # best effort: report and carry on with the next pair
            print(f"Error during alignment of pieces {current_piece.id + 1} and {next_piece.id + 1}: {str(e)}")

    print("\nAlignment Data:")
    for entry in alignment_data:
        print(f"Alignment {entry['id']} (Piece {entry['piece_id']}):")
        print(f"  Original Center: {entry['original_center']}")
        print(f"  New Center: {entry['new_center']}")
        print(f"  Rotation: {entry['rotation']:.2f} degrees")
        print()

    return updated_corners, alignment_data


def generate_alignment_path(alignment_data):
    """Expand each alignment into eight pick-and-place arm waypoints.

    Every waypoint is ``(piece_id, x, y, z, rotation, gripper_closed)``; a
    ``False`` gripper flag means open and ``True`` means closed.

    - Pick at the original center: z 100 (open), -40 (open), -40 (closed),
      100 (closed).
    - Place at the new center: z 100 (closed), -40 (closed), -40 (open),
      100 (open).

    Returns:
        A flat list of waypoints, eight per alignment, in input order.
    """
    pick = ((100, False), (-40, False), (-40, True), (100, True))
    place = ((100, True), (-40, True), (-40, False), (100, False))

    waypoints = []
    for alignment in alignment_data:
        piece_id = alignment['piece_id']
        start = alignment['original_center']
        end = alignment['new_center']
        rotation = alignment['rotation']

        for center, phase in ((start, pick), (end, place)):
            waypoints.extend(
                (piece_id, center[0], center[1], z, rotation, closed) for z, closed in phase
            )

    return waypoints

def send_command(ser, command):
    """Write ``command`` plus a newline to ``ser`` and return the reply.

    After a 0.1 s pause, every line already buffered is read, decoded,
    stripped and concatenated with no separator. The command and the reply
    are both echoed to stdout.
    """
    print(f"Sending command: {command}")
    ser.write(f"{command}\n".encode())
    time.sleep(0.1)  # Wait for Arduino to process
    reply_parts = []
    while ser.in_waiting:
        reply_parts.append(ser.readline().decode().strip())
    response = "".join(reply_parts)
    print(f"Arduino response: {response}")
    return response

def move_arm(ser, x, y, z, gripper):
    """Send an ``I`` move command for ``(x, y, z)`` and report whether it succeeded.

    The command is ``I<x>,<y>,<z>,90,<true|false>``, with coordinates to two
    decimals and ``true`` meaning gripper closed. Success means the reply
    from ``send_command`` contains ``"OK"``.

    Returns:
        ``True`` on success, ``False`` otherwise.
    """
    gripper_flag = "true" if gripper else "false"  # "true" for closed, "false" for open
    reply = send_command(ser, f"I{x:.2f},{y:.2f},{z:.2f},{90},{gripper_flag}")
    if "OK" not in reply:
        print(f"Failed to move arm. Response: {reply}")
        return False
    print(f"Arm moved to ({x:.2f}, {y:.2f}, {z:.2f}) with gripper {'closed' if gripper else 'open'}")
    return True

def image_to_arm_coordinates(rbf, x, y, constant_z, scale_percent):
    """Map an image point to an arm ``(x, y, z)`` target.

    ``(x, y)`` is divided by ``scale_percent / 100`` and truncated to ints,
    then passed to ``rbf`` as a single query point. The first two outputs
    become the arm x/y, and ``constant_z`` is appended unchanged.

    Returns:
        A numpy array ``[arm_x, arm_y, constant_z]``.
    """
    factor = scale_percent / 100
    query = [[int(x / factor), int(y / factor)]]
    mapped = rbf(query)[0]
    return np.array([mapped[0], mapped[1], constant_z])

def _execute_arm_path(ser, path, rbf, scale_percent):
    """Drive the arm through every waypoint in ``path``, reporting errors instead of raising.

    Each ``(piece_id, x, y, z, rotation, gripper_closed)`` step is mapped to
    arm coordinates with ``image_to_arm_coordinates``, sent with ``move_arm``
    and followed by a 2 s pause. Failed moves are reported and skipped.
    NOTE(review): ``rotation`` is not sent to the arm.
    """
    try:
        for step in path:
            piece_id, x, y, z, _rotation, gripper_closed = step

            target = image_to_arm_coordinates(rbf, x, y, z, scale_percent)

            success = move_arm(ser, target[0], target[1], target[2], gripper_closed)
            time.sleep(2)
            if not success:
                print(f"Failed to move arm for piece {piece_id}. Skipping to next step.")
                continue

            print(f"Completed step for piece {piece_id}")

        print("Puzzle assembly complete")

    except Exception as e:
        print(f"An error occurred: {e}")


def main(piece_num):
    """Solve the puzzle photo and drive the arm to assemble it.

    Steps:

    1. Extract and analyse ``piece_num`` pieces from the hard-coded photo.
    2. Assemble the four border edges and fill the interior grid.
    3. Compute per-piece alignment moves and turn them into arm waypoints.
    4. Send the waypoints over serial.

    The serial port is always closed, even if the initial command fails.
    Nothing is executed on the arm if no starting corner is found.

    Raises:
        ValueError: if the first border edge's length fits neither side of
            the 3x4 puzzle.
    """
    global used_pieces
    image_path = r'C:\Users\steph\OneDrive\Masters\cars_final.jpg'

    masks, piece_images, bounding_boxes, crop_values = extract_puzzle_pieces(image_path)

    pieces = []
    for i in range(piece_num):
        edge_types, edge_shapes, edge_colors, _, _, _, corners, _ = analyze_piece_edges(piece_images[i], masks[i], num_samples=50, inset_distance=-10)
        edges = [Edge(edge_type, shape, colors) for edge_type, shape, colors in zip(edge_types, edge_shapes, edge_colors)]
        pieces.append(Piece(i, edges, piece_images[i], corners, original_position=bounding_boxes[i], crop_values=crop_values))

    comparison_results = compare_all_pieces(pieces)
    used_pieces = set()

    all_border_pieces = set(p.id for p in pieces if p.piece_type in ['border', 'corner'])

    best_corner, best_match = find_best_starting_corner(pieces, comparison_results)
    if not (best_corner and best_match):
        return

    corner_piece, (_, corner_edge, _) = best_corner, best_match

    print("\nFirst edge:")
    first_edge = assemble_edge(corner_piece, get_opposite_edge(corner_edge), pieces, comparison_results, all_border_pieces)
    first_corner = first_edge[0]
    remaining_edge_lengths = determine_edge_lengths(len(first_edge))
    if remaining_edge_lengths is None:
        raise ValueError(f"Unexpected first edge length {len(first_edge)}; cannot infer remaining edges")

    second_edge = assemble_next_edge(first_edge, pieces, comparison_results, all_border_pieces, max_pieces=remaining_edge_lengths[0], first_corner=first_corner)
    third_edge = assemble_next_edge(second_edge, pieces, comparison_results, all_border_pieces, max_pieces=remaining_edge_lengths[1], first_corner=first_corner)
    fourth_edge = assemble_next_edge(third_edge, pieces, comparison_results, all_border_pieces, max_pieces=remaining_edge_lengths[2], first_corner=first_corner)

    # consecutive edges share their corner piece, so drop each edge's first element
    border_pieces = first_edge + second_edge[1:] + third_edge[1:] + fourth_edge[1:]

    puzzle_size = (3, 4)
    grid = create_puzzle_grid(border_pieces, puzzle_size)

    # keep the original piece order so placement ties are broken deterministically
    border_ids = {p.id for p in border_pieces}
    interior_pieces = [piece for piece in pieces if piece.id not in border_ids]

    fill_puzzle_interior(grid, interior_pieces, comparison_results)

    edges = [first_edge, second_edge, third_edge, fourth_edge]
    edges = [[piece for piece in edge if piece is not None] for edge in edges if edge is not None]

    print("Edges before alignment:")
    for i, edge in enumerate(edges):
        print(f"Edge {i + 1}: {[piece.id + 1 for piece in edge]}")

    # border pieces in assembly order (deduplicated), then the remaining pieces
    all_pieces = list(dict.fromkeys(piece for edge in edges for piece in edge))
    remaining_pieces = [piece for piece in pieces if piece not in all_pieces]
    all_pieces.extend(remaining_pieces)

    visualize_all_connections(image_path, pieces)

    _, alignment_data = automatic_alignment([all_pieces])

    path = generate_alignment_path(alignment_data)

    for step in path:
        print(f"Piece {step[0]}: (x: {step[1]}, y: {step[2]}, z: {step[3]})")

    port = 'COM3'
    baud_rate = 115200
    timeout = 5
    scale_percent = 80
    # (image_x, image_y, arm_x, arm_y) pairs used to fit the image -> arm mapping
    calibration_points = [
        (622, 1272, 200.00, 200.00),
        (548, 614, 200.00, -200.00),
        (882, 1398, 300.00, 300.00),
        (774, 442, 300.00, -300.00),
        (394, 1362, 100.00, 200.00),
        (310, 546, 100.00, -200.00),
        (542, 1072, 200.00, 100.00),
        (508, 826, 200.00, -100.00),
        (744, 1008, 300.00, 100.00),
        (718, 850, 300.00, -100.00),
        (806, 834, 250.00, 250.00),
        (660, 520, 250.00, -250.00),
        (888, 364, 350.00, 350.00),
        (886, 308, 350.00, -350.00),
        (528, 1470, 150.00, 250.00),
        (428, 478, 150.00, -250.00),
        (872, 1136, 350.00, 150.00),
        (820, 720, 350.00, -150.00),
        (882, 1334, 350.00, 250.00),
        (854, 556, 350.00, -250.00),
        (818, 1414, 275.00, 275.00),
        (724, 464, 275.00, -275.00),
        (914, 1532, 325.00, 325.00),
        (836, 380, 325.00, -325.00),
        (1016, 1598, 375.00, 375.00),
        (952, 326, 375.00, -375.00),
        (754, 1430, 240.00, 260.00),
        (622, 452, 240.00, -260.00),
        (850, 1500, 280.00, 320.00),
        (760, 394, 280.00, -320.00),
        (884, 1402, 320.00, 280.00),
        (822, 460, 320.00, -280.00),
        (956, 1272, 360.00, 240.00),
        (908, 572, 360.00, -240.00),
        (1010, 1398, 380.00, 300.00),
        (918, 456, 380.00, -300.00),
        (1012, 1188, 400.00, 200.00),
        (962, 666, 400.00, -200.00),
        (596, 880, 250.00, 0.00),
        (690, 880, 300.00, 0.00),
        (816, 864, 350.00, 0.00),
        (818, 862, 175.00, 0.00),
        (542, 914, 225.00, 0.00)
    ]

    image_coords = np.array([(p[0], p[1]) for p in calibration_points])
    arm_coords = np.array([(p[2], p[3]) for p in calibration_points])

    rbf = RBFInterpolator(image_coords, arm_coords, kernel='thin_plate_spline')

    ser = serial.Serial(port, baud_rate, timeout=timeout)
    try:
        time.sleep(2)
        # initial "1" command; its meaning is defined by the Arduino firmware (TODO confirm)
        send_command(ser, "1")
        _execute_arm_path(ser, path, rbf, scale_percent)
    finally:
        # close the port even if the initial command fails
        if ser.is_open:
            ser.close()
            print("Serial connection closed.")

if __name__ == '__main__':
    # Run the full pipeline for the 12-piece (3x4) puzzle.
    main(piece_num=12)

Image to IK calibration code

In [None]:
import serial
import time
import cv2
import numpy as np

# Serial link settings for the arm's Arduino controller.
port = 'COM3'
baud_rate = 115200
timeout = 5  # read timeout in seconds, passed straight to serial.Serial

# (image_x, image_y, arm_x, arm_y) pairs gathered so far, plus the arm position
# that is still waiting for a matching click on the image.
calibration_points = []
current_arm_position = None

# All pieces share one height, so a fixed z reduces calibration to a 2D -> 2D mapping.
constant_z = -50

# Predefined arm coordinates: every (x, y) seed is visited at +y and then -y,
# followed by a handful of targets on the y = 0 axis.
arm_coordinates = [
    (x, sign * y)
    for x, y in (
        (200, 200), (300, 300), (100, 200), (200, 100), (300, 100),
        (250, 250), (350, 350), (150, 250), (350, 150), (350, 250),
        (275, 275), (325, 325), (375, 375),
        (240, 260), (280, 320), (320, 280), (360, 240), (380, 300), (400, 200),
    )
    for sign in (1, -1)
] + [(250, 0), (300, 0), (350, 0), (175, 0), (225, 0)]

ser = serial.Serial(port, baud_rate, timeout=timeout)
time.sleep(2)  # settle delay after opening the port (presumably lets the board reset)

def send_command(command):
    """Send one newline-terminated command over the module-level ``ser`` link.

    NOTE(review): the assembly code calls ``send_command(ser, cmd)`` with two
    arguments; running this cell in the same kernel rebinds the name to this
    one-argument version — confirm the two cells are not mixed.

    Args:
        command: Command text; a trailing newline is appended before sending.

    Returns:
        The reply lines that were already buffered after a fixed 0.1 s wait,
        each stripped and joined with " | " (empty string if nothing arrived).
        Replies arriving later stay in the input buffer and are picked up by
        the next call.
    """
    print(f"Sending command: {command}")
    ser.write(f"{command}\n".encode())
    time.sleep(0.1)  # Wait for Arduino to process

    reply_lines = []
    while ser.in_waiting:
        # Decode leniently: a single non-UTF-8 byte on the line (e.g. noise
        # while the board boots) would otherwise raise UnicodeDecodeError and
        # abort the whole calibration session.
        line = ser.readline().decode(errors="replace").strip()
        if line:
            reply_lines.append(line)
    # Keep line boundaries visible instead of gluing lines together.
    response = " | ".join(reply_lines)

    print(f"Arduino response: {response}")
    return response

def move_arm(x, y, z, gripper=True, rotation=90):
    """Send an ``I`` (position) command for target (x, y, z) to the Arduino.

    Args:
        x, y, z: Arm-space target; each is formatted with two decimals.
        gripper: True sends "true" (reported as gripper closed), False sends "false".
        rotation: Value placed in the command's fourth field. It defaults to
            the 90 that was previously hard-coded. Presumably this is a wrist
            angle in degrees; confirm against the firmware.

    Returns:
        Always True. The Arduino's reply is printed by send_command but not
        inspected here, so a failed move is not detected.
    """
    gripper_state = "true" if gripper else "false"
    command = f"I{x:.2f},{y:.2f},{z:.2f},{rotation},{gripper_state}"
    send_command(command)

    print(f"Arm moved to ({x:.2f}, {y:.2f}, {z:.2f}) with gripper {'closed' if gripper else 'open'}")
    return True


def resize_frame(frame, scale_percent):
    """Return *frame* scaled to *scale_percent* percent with area interpolation.

    Target width and height are truncated to ints; cv2.resize takes the size
    as (width, height).
    """
    rows, cols = frame.shape[:2]
    new_size = (
        int(cols * scale_percent / 100),
        int(rows * scale_percent / 100),
    )
    return cv2.resize(frame, new_size, interpolation=cv2.INTER_AREA)

def mouse_callback(event, x, y, flags, param):
    """OpenCV mouse handler that records one calibration pair per left click.

    Ignores everything except a left-button press made while an arm position is
    pending. The clicked pixel in the resized display image is mapped back to
    original-image pixels. It is then stored with the pending arm (x, y) in
    calibration_points, and the pending position is cleared.
    """
    global current_arm_position
    if event != cv2.EVENT_LBUTTONDOWN or current_arm_position is None:
        return

    ratio = scale_percent / 100
    img_x = int(x / ratio)
    img_y = int(y / ratio)
    print(f"Marked image coordinates: ({img_x}, {img_y})")

    arm_x, arm_y, _ = current_arm_position
    calibration_points.append((img_x, img_y, arm_x, arm_y))
    print(f"Calibration point added: Image({img_x}, {img_y}) -> Arm({arm_x:.2f}, {arm_y:.2f})")
    current_arm_position = None

image_path = r"C:\Users\steph\Downloads\calibration_image.jpg"
original_image = cv2.imread(image_path)
if original_image is None:
    # The serial port was opened earlier in this cell. Release it before
    # aborting; otherwise the kernel keeps holding the port and a re-run
    # may fail to open it.
    if ser.is_open:
        ser.close()
    raise ValueError("Could not load the image. Check the file path.")

# Display scale in percent; mouse_callback divides click positions by this to
# recover original-image pixel coordinates.
scale_percent = 50
image = resize_frame(original_image, scale_percent)

cv2.namedWindow("Calibration Image")
cv2.setMouseCallback("Calibration Image", mouse_callback)

current_point_index = 0  # index into arm_coordinates of the next target to visit

try:
    print("Click on the image to mark the current arm position.")
    print("Press 'm' to move to the next calibration point.")
    print("Press 'r' to reset the calibration.")
    print("Press 'q' to quit.")

    while True:
        display_image = image.copy()

        # Stored points are in original-image pixels; scale them down onto the
        # resized display image before drawing.
        for point in calibration_points:
            cv2.circle(display_image, (int(point[0] * scale_percent / 100), int(point[1] * scale_percent / 100)), 5, (0, 255, 0), -1)

        cv2.imshow("Calibration Image", display_image)

        key = cv2.waitKey(1) & 0xFF
        # If the user closed the window via its title bar, stop. Otherwise the
        # next imshow would typically recreate a window that has no mouse
        # callback attached, so clicks could no longer be recorded.
        if cv2.getWindowProperty("Calibration Image", cv2.WND_PROP_VISIBLE) < 1:
            break
        if key == ord('q'):
            break
        elif key == ord('m'):
            if current_point_index < len(arm_coordinates):
                x, y = arm_coordinates[current_point_index]
                if move_arm(x, y, constant_z):
                    # Pending position consumed by the next left click (mouse_callback).
                    current_arm_position = (x, y, constant_z)
                    print(f"Moved to point {current_point_index + 1}. Click on the image where the arm is positioned.")
                current_point_index += 1
            else:
                print("All points have been calibrated.")
        elif key == ord('r'):
            current_point_index = 0
            calibration_points.clear()
            print("Reset calibration. Press 'm' to start over.")

except KeyboardInterrupt:
    print("\nExiting...")
except Exception as e:
    # Top-level boundary: report and fall through to cleanup/saving below.
    print(f"An error occurred: {e}")
finally:
    if ser.is_open:
        ser.close()
        print("Serial connection closed.")
    cv2.destroyAllWindows()


    # Persist whatever pairs were collected, one "img_x,img_y,arm_x,arm_y" per line.
    if calibration_points:
        with open("calibration_data.txt", "w") as f:
            for point in calibration_points:
                f.write(f"{point[0]},{point[1]},{point[2]:.2f},{point[3]:.2f}\n")
        print("Calibration data saved to calibration_data.txt")

Adaptive thresholding: extract individual puzzle pieces from a photo

In [None]:
def find_pieces(image_path, piece_number):
    """Segment up to *piece_number* puzzle pieces from the image at *image_path*.

    The processing pipeline is:
    1. Convert the image to greyscale.
    2. Apply an adaptive Gaussian threshold with THRESH_BINARY_INV (block size
       11, C=5), so pixels darker than their local weighted mean minus 5
       become 255.
    3. Blur the result.
    4. Find the external contours and keep the *piece_number* largest by area.

    Args:
        image_path: Path to the image file.
        piece_number: How many of the largest contours to keep.

    Returns:
        (extracted_pieces_mask, extracted_pieces_colour): two parallel lists
        ordered from largest to smallest contour area. Each mask is a uint8
        single-channel image with the piece filled with 255. Each colour image
        is the original image with everything outside that piece set to 0.

    Raises:
        ValueError: if the image cannot be read from *image_path*.
    """
    original_image = cv2.imread(image_path)
    if original_image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not load the image at {image_path!r}. Check the file path.")

    grey_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
    threshold_image = cv2.adaptiveThreshold(grey_image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 5)
    # findContours treats any non-zero pixel as foreground. Blurring the binary
    # image therefore slightly grows regions and can bridge small gaps.
    blurred_image = cv2.GaussianBlur(threshold_image, (7, 7), 1.5)
    contours, _ = cv2.findContours(blurred_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    sorted_contours = sorted(contours, key=cv2.contourArea, reverse=True)[:piece_number]

    extracted_pieces_colour = []
    extracted_pieces_mask = []

    for contour in sorted_contours:
        # Fresh single-piece mask with the same size/dtype as the greyscale image.
        contour_mask = np.zeros(grey_image.shape, dtype=np.uint8)
        cv2.drawContours(contour_mask, [contour], -1, 255, thickness=cv2.FILLED)

        piece = cv2.bitwise_and(original_image, original_image, mask=contour_mask)
        extracted_pieces_colour.append(piece)
        extracted_pieces_mask.append(contour_mask)

    return extracted_pieces_mask, extracted_pieces_colour