In [64]:
import cv2
import math
import numpy as np
from typing import List
from collections import deque

In [65]:
# Load both model photos (default colour read) and their single-channel class masks.
img1 = cv2.imread("model1.png")
mask1 = cv2.imread("model1_class2.png", cv2.IMREAD_GRAYSCALE)
img2 = cv2.imread("model2.jpg")
mask2 = cv2.imread("model2_class2.png", cv2.IMREAD_GRAYSCALE)

# cv2.imread signals failure by returning None instead of raising, so fail here
# with the offending path rather than with an AttributeError further down.
missing = [
    path
    for path, loaded in (
        ("model1.png", img1),
        ("model1_class2.png", mask1),
        ("model2.jpg", img2),
        ("model2_class2.png", mask2),
    )
    if loaded is None
]
if missing:
    raise FileNotFoundError(f"could not read image file(s): {missing}")

# ndarray.shape is (rows, cols) = (height, width), whereas cv2.resize expects
# dsize as (width, height).
h, w = img1.shape[:2]

# Bring the second image and its mask to the first image's size. The mask is
# resized with nearest-neighbour so no blended intermediate values are created
# along the region boundary.
img2 = cv2.resize(img2, (w, h))
mask2 = cv2.resize(mask2, (w, h), interpolation=cv2.INTER_NEAREST)

In [66]:
# Trace only the outermost boundaries of each mask, keeping every boundary
# pixel (no chain compression); the second return value is discarded.
contours1, _ = cv2.findContours(
    image=mask1, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_NONE
)
contours2, _ = cv2.findContours(
    image=mask2, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_NONE
)

In [67]:
def _simplify_largest_contour(contours, eps_ratio=0.001):
    """Return a polygonal approximation of the largest contour in `contours`.

    The previous loop overwrote its result on every iteration, so whichever
    contour happened to come last was kept, and an empty contour list left the
    result name undefined. Selecting by area makes the choice explicit.

    Args:
        contours: sequence of contours as returned by cv2.findContours.
        eps_ratio: approximation tolerance as a fraction of the contour's
            closed perimeter.

    Returns:
        The cv2.approxPolyDP result for the contour with the largest area.

    Raises:
        ValueError: if `contours` is empty.
    """
    if len(contours) == 0:
        raise ValueError("no contours found in mask")
    largest = max(contours, key=cv2.contourArea)
    tolerance = cv2.arcLength(largest, True) * eps_ratio
    return cv2.approxPolyDP(largest, tolerance, True)


approx_poly1 = _simplify_largest_contour(contours1)
approx_poly2 = _simplify_largest_contour(contours2)

In [68]:
import random

# Upper bound on the number of control points drawn from each polygon.
N_SAMPLE_POINTS = 50

# A dedicated, seeded generator keeps the selection identical on every
# Restart & Run All without touching the global `random` state.
_sampler = random.Random(42)

# Both polygons must contribute the same number of points because they are
# later paired index-by-index; clamping also avoids the ValueError that
# random.sample raises when a polygon has fewer vertices than requested.
n_samples = min(N_SAMPLE_POINTS, len(approx_poly1), len(approx_poly2))

approx_poly1 = np.array(_sampler.sample(list(approx_poly1), n_samples))
approx_poly2 = np.array(_sampler.sample(list(approx_poly2), n_samples))

In [69]:
# Drop the singleton axis 1, convert to plain nested lists, and order the
# points by their first coordinate (ascending; the sort is stable).
sorted_approx_poly1 = approx_poly1.squeeze(axis=1).tolist()
sorted_approx_poly1.sort(key=lambda pt: pt[0])

sorted_approx_poly2 = approx_poly2.squeeze(axis=1).tolist()
sorted_approx_poly2.sort(key=lambda pt: pt[0])

In [70]:
def connect_shortest_pt(src):
    """Order points into a greedy nearest-neighbour path.

    Starting from ``src[0]``, the path repeatedly jumps to the closest point
    (Euclidean distance) that has not been visited yet. When several points
    are equally close, the one that appears earliest in ``src`` wins.

    Args:
        src: indexable sequence of ``(x, y)`` pairs (e.g. a list of
            two-element lists). It is not modified.

    Returns:
        A new list holding every point of ``src`` in visiting order, followed
        by one extra copy of the final point. The trailing duplicate is kept
        on purpose: existing callers strip the last element after calling
        this function. An empty ``src`` yields ``[]``; a single point ``p``
        yields ``[p, p]``.
    """
    remaining = list(src)
    if not remaining:
        return []

    current = remaining.pop(0)
    path = [current]

    while remaining:
        cx, cy = current
        best_idx = 0
        best_dist = math.inf
        for i, (px, py) in enumerate(remaining):
            d = math.sqrt((px - cx) ** 2 + (py - cy) ** 2)
            # Strict comparison keeps the earliest candidate on ties.
            if d < best_dist:
                best_dist = d
                best_idx = i
        current = remaining.pop(best_idx)
        path.append(current)

    # Preserve the historical output shape: the last point appears twice.
    path.append(current)
    return path

In [71]:
# Re-chain each x-sorted point list into a nearest-neighbour path.
real_sorted_approx_poly1, real_sorted_approx_poly2 = (
    connect_shortest_pt(points)
    for points in (sorted_approx_poly1, sorted_approx_poly2)
)

In [72]:
# Rasterise the ordered path for image 1 as a filled white polygon on a black
# single-channel canvas of the image's size, then save it for inspection.
h, w = img1.shape[:2]

n_pts1 = np.array(real_sorted_approx_poly1, np.int32).reshape(-1, 1, 2)

test_img1 = np.zeros((h, w), dtype=np.uint8)
img1_mask = cv2.fillPoly(test_img1, [n_pts1], color=255)

cv2.imwrite("test1.png", img1_mask)

True

In [73]:
# Same rasterisation for image 2: filled white polygon on a black
# single-channel canvas matching the image's size, saved for inspection.
h, w = img2.shape[:2]

n_pts2 = np.array(real_sorted_approx_poly2, np.int32).reshape(-1, 1, 2)

test_img2 = np.zeros((h, w), dtype=np.uint8)
img2_mask = cv2.fillPoly(test_img2, [n_pts2], color=255)

cv2.imwrite("test2.png", img2_mask)

True

In [74]:
def _drop_trailing_duplicate(points):
    """Return `points` without its last element when that element repeats the one before it.

    connect_shortest_pt ends its result with a second copy of the final point.
    Removing it only when the duplicate is actually present makes this cell
    safe to re-run: an unconditional ``[:-1]`` would discard a real vertex on
    every additional execution.
    """
    if len(points) >= 2 and points[-1] == points[-2]:
        return points[:-1]
    return points


real_sorted_approx_poly1 = _drop_trailing_duplicate(real_sorted_approx_poly1)
real_sorted_approx_poly2 = _drop_trailing_duplicate(real_sorted_approx_poly2)

In [75]:
def _normalize_direction(points):
    """Reverse the traversal direction of a path when its second point lies below its first.

    The first point stays in place. If its y coordinate is smaller than the
    second point's, the rest of the path is reversed so the path leaves the
    start in the opposite direction. Paths with fewer than two points have no
    direction and are returned unchanged (the previous inline code raised
    IndexError for them).

    Args:
        points: list of (x, y) pairs.

    Returns:
        `points` itself when no change is needed, otherwise a new list.
    """
    if len(points) < 2 or points[0][1] >= points[1][1]:
        return points
    # Keep index 0, then walk the remaining points from last back to index 1.
    return [points[0]] + points[:0:-1]


real_sorted_approx_poly1 = _normalize_direction(real_sorted_approx_poly1)
real_sorted_approx_poly2 = _normalize_direction(real_sorted_approx_poly2)

In [76]:
# Convert each ordered point list to float32 and add a leading axis of size 1,
# giving arrays shaped (1, number_of_points, 2).
point_on_mask1 = (
    np.array(real_sorted_approx_poly1).astype(np.float32).reshape(1, -1, 2)
)
point_on_mask2 = (
    np.array(real_sorted_approx_poly2).astype(np.float32).reshape(1, -1, 2)
)

In [77]:
# Thin-plate-spline transformer plus a one-to-one pairing of the control
# points: point i of the first set is matched with point i of the second.
tps = cv2.createThinPlateSplineShapeTransformer()
matches = [cv2.DMatch(i, i, 0) for i in range(point_on_mask1.shape[1])]

In [78]:
# Expand the in-memory polygon masks to three channels so they can be AND-ed
# with the colour images. This yields the same pixels as re-reading the saved
# PNGs with cv2.imread's default colour mode, but without depending on the
# files on disk (a failed read would return None and make bitwise_and fail
# with an unrelated-looking error).
test1 = cv2.cvtColor(img1_mask, cv2.COLOR_GRAY2BGR)
origin_src = cv2.bitwise_and(img1, test1)
test2 = cv2.cvtColor(img2_mask, cv2.COLOR_GRAY2BGR)
dist_src = cv2.bitwise_and(img2, test2)

# NOTE(review): the points of image 2 are passed first and those of image 1
# second, yet it is the image-1 cut-out that gets warped. This ordering is
# presumably deliberate (warpImage is commonly reported to apply the mapping
# in the reverse direction) — confirm against the rendered result.
tps.estimateTransformation(point_on_mask2, point_on_mask1, matches)
warped_img = tps.warpImage(origin_src)

In [79]:
# Save the warped cut-out of image 1; the call's return value is shown as the
# cell output.
cv2.imwrite("result.jpg", warped_img)

True