In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from PIL import ImageDraw
import kornia.geometry
import cv2 
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go

# 1. Separate the three parts of the picture

This part was written by our assistant, Martin Everaert.

In [2]:

def translate_each_sector_to_center_soft(image, center, angle, move_distance):
    """Split an image into three soft angular sectors and shift each one toward the center.

    The plane is divided into three 120-degree sectors around ``center``. The first
    boundary is at ``angle`` (radians, in image coordinates where y points down).
    Each sector is isolated with a soft sigmoid mask, so the whole operation stays
    differentiable w.r.t. ``angle`` and ``move_distance``. The masked sector is then
    translated by ``move_distance * width`` pixels along its bisector, toward the
    center (the translation is the negative of the bisector direction).

    Args:
        image: tensor of shape (B, C, H, W).
        center: (x, y) apex of the sectors as fractions of width / height
            (tensor or sequence of two numbers).
        angle: rotation of the first sector boundary in radians (tensor, e.g. a
            shape-(1,) ``nn.Parameter``, or a float).
        move_distance: translation length as a fraction of the image width
            (tensor or float). Presumably a single value; TODO confirm if more are needed.

    Returns:
        Tuple ``(img_0, img_1, img_2, mask_0, mask_1, mask_2)``. The images have
        shape (B, C, H, W) and the warped masks have shape (1, 1, H, W).
    """
    batch_size, _, height, width = image.size()
    device = image.device

    # as_tensor returns the input itself when dtype/device already match, so an
    # nn.Parameter passed in keeps its autograd connection.
    angle = torch.as_tensor(angle, dtype=torch.float32, device=device)
    center = torch.as_tensor(center, dtype=torch.float32, device=device)
    move_distance = torch.as_tensor(move_distance, dtype=torch.float32, device=device).reshape(-1)

    # Start angles of the three 120-degree sectors.
    angles = angle + torch.tensor([0.0, 2 * np.pi / 3, 4 * np.pi / 3], device=device)

    # Sector apex in pixel coordinates.
    center_x = center[0] * width
    center_y = center[1] * height

    # Pixel coordinate grids; "ij" indexing gives grid_y rows and grid_x columns.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(height, dtype=torch.float32, device=device),
        torch.arange(width, dtype=torch.float32, device=device),
        indexing="ij",
    )

    # Offset of every pixel from the apex, shape (2, H, W). This is loop-invariant.
    offsets = torch.stack([grid_x - center_x, grid_y - center_y])

    move_distance_scaled = move_distance * width

    transformed_images = []
    masks = []

    for i in range(3):
        start_angle = angles[i]
        end_angle = angles[(i + 1) % 3]

        # Normals of the two boundary rays (ray direction rotated by +90 degrees).
        # torch.stack, not torch.tensor, keeps them in the autograd graph, so the
        # masks contribute gradients w.r.t. `angle`.
        start_normal = torch.stack([torch.cos(start_angle + np.pi / 2), torch.sin(start_angle + np.pi / 2)])
        end_normal = torch.stack([torch.cos(end_angle + np.pi / 2), torch.sin(end_angle + np.pi / 2)])

        # Signed distance of every pixel to each boundary line.
        start_signed_distance = torch.sum(offsets * start_normal[:, None, None], dim=0)
        end_signed_distance = torch.sum(offsets * end_normal[:, None, None], dim=0)

        # Soft half-plane masks. /10 sets the ramp width in pixels, and -1 places the
        # 0.5 level 10 px inside each boundary.
        start_mask = torch.sigmoid(start_signed_distance / 10 - 1)
        end_mask = torch.sigmoid(-end_signed_distance / 10 - 1)
        mask = start_mask * end_mask

        # Bisector direction of this sector.
        if start_angle <= end_angle:
            mid_angle = (start_angle + end_angle) / 2
        else:
            # Wrap-around when end_angle < start_angle, e.g. start = 2pi - 1 and end = 1.
            mid_angle = ((start_angle + end_angle + 2 * torch.pi) / 2) % (2 * torch.pi)

        # Pure translation toward the apex, along the negative bisector.
        M = kornia.geometry.get_affine_matrix2d(
            translations=torch.stack(
                [-move_distance_scaled * torch.cos(mid_angle), -move_distance_scaled * torch.sin(mid_angle)],
                dim=1,
            ),
            center=torch.stack([center_x, center_y]).unsqueeze(0),
            scale=torch.ones(1, 2, device=device),
            angle=torch.zeros(1, device=device),
        )[:, :2, :]

        # Apply the mask, then translate. warp_affine needs one matrix per batch item.
        mask_4d = mask[None, None]  # (1, 1, H, W)
        transformed_image = kornia.geometry.warp_affine(
            image * mask_4d, M.expand(batch_size, -1, -1), dsize=(height, width)
        )
        transformed_mask = kornia.geometry.warp_affine(mask_4d, M, dsize=(height, width))

        transformed_images.append(transformed_image)
        masks.append(transformed_mask)

    return transformed_images[0], transformed_images[1], transformed_images[2], masks[0], masks[1], masks[2]

In [3]:
class KaleidoscopeModelSoft(nn.Module):
    """Learnable soft kaleidoscope split: three sectors around a fixed center.

    ``angle`` and ``move_distance`` are trainable ``nn.Parameter``s. ``center`` is held
    fixed: it is registered as a non-persistent buffer, so it follows ``.to()`` /
    ``.double()``. It is not returned by ``parameters()`` and is not stored in
    ``state_dict()``.

    Args:
        center_init: (x, y) apex of the sectors as fractions of width / height.
        angle_init: initial rotation of the first sector boundary, in radians.
        move_distance_init: initial translation as a fraction of the image width.
    """

    def __init__(self, center_init, angle_init, move_distance_init):
        super().__init__()
        # Fixed center. Replace with nn.Parameter to make it learnable.
        self.register_buffer("center", torch.tensor(center_init, dtype=torch.float32), persistent=False)
        # Explicit float dtype: an int init (e.g. 0) would otherwise create an integer
        # tensor, which nn.Parameter cannot mark as requiring grad.
        self.angle = nn.Parameter(torch.tensor([angle_init], dtype=torch.float32))
        self.move_distance = nn.Parameter(torch.tensor([move_distance_init], dtype=torch.float32))

    def forward(self, x):
        """Split batch ``x`` into three translated soft sectors.

        Returns the 6-tuple produced by ``translate_each_sector_to_center_soft``.
        """
        return translate_each_sector_to_center_soft(x, self.center, self.angle, self.move_distance)

In [None]:
# Load the image as RGB so the green boundary colour drawn below is valid
# (grayscale / RGBA / palette files would otherwise break the drawing).
image_path = "pen.jpg"
image = Image.open(image_path).convert("RGB")

# ToTensor already returns float32 in [0, 1] with shape (C, H, W) for uint8 RGB images,
# so no further normalisation is needed.
transform = transforms.ToTensor()
image_tensor = transform(image).unsqueeze(0)  # Add batch dimension -> (1, C, H, W)

# Create the kaleidoscope model
model = KaleidoscopeModelSoft(
    center_init=[0.6, 0.6],
    angle_init=0.0,
    move_distance_init=0.2,
)

# Loss functions and weights
criterion_l1 = nn.L1Loss()
criterion_l2 = nn.MSELoss()
# Weight of the overlap-consistency term for pixels in [0, 1]. This is equivalent to a
# weight of 1e11 on pixels scaled to [0, 1/255]. The MSE is a per-element mean that is
# further divided by the summed overlap mask, so the raw term is tiny.
alpha = 1e11 / 255**2
beta = 1
gamma = 1

# Define the optimizer (only angle and move_distance are trainable)
optimizer = optim.Adam(model.parameters(), lr=0.01)


def _overlap_l2(img_a, mask_a, img_b, mask_b):
    """MSE between two warped sectors restricted to their soft overlap, normalised by overlap area."""
    overlap = mask_a * mask_b
    return criterion_l2(overlap * img_a, overlap * img_b) / torch.sum(overlap)


# Training loop
num_epochs = 200
log_every = 10
vizualise = True
for epoch in range(num_epochs):
    # Forward pass
    transformed_0, transformed_1, transformed_2, mask_0, mask_1, mask_2 = model(image_tensor)
    sectors = [(transformed_0, mask_0), (transformed_1, mask_1), (transformed_2, mask_2)]

    # The three translated sectors should agree wherever they overlap.
    loss_img = sum(alpha * _overlap_l2(*sectors[a], *sectors[b]) for a, b in ((0, 1), (0, 2), (1, 2)))

    # Some regularization terms (center is fixed, so loss_reg_center is constant)
    loss_reg_center = beta * criterion_l2(model.center, 0.5 * torch.ones(2))
    loss_reg_move_distance = gamma * criterion_l2(model.move_distance, 0.2 * torch.ones(1))

    loss = loss_img + loss_reg_center + loss_reg_move_distance

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print progress periodically instead of every epoch
    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print(f'Epoch [{epoch+1}/{num_epochs}], loss_img: {loss_img.item()}, loss_reg_center: {loss_reg_center.item()}, loss_reg_move_distance: {loss_reg_move_distance.item()}')
        print(model.center, model.angle, model.move_distance)

# Convert the final sectors to PIL unconditionally, because the next cell saves them.
to_pil = transforms.ToPILImage()
transformed_images_0_pil = to_pil(transformed_0.detach().squeeze(0))
transformed_images_1_pil = to_pil(transformed_1.detach().squeeze(0))
transformed_images_2_pil = to_pil(transformed_2.detach().squeeze(0))
mask_0_pil = to_pil(mask_0.detach().squeeze(0))
mask_1_pil = to_pil(mask_1.detach().squeeze(0))
mask_2_pil = to_pil(mask_2.detach().squeeze(0))

if vizualise:
    with torch.no_grad():
        # Sum of the three sectors, rescaled so its maximum is 1 for display.
        transformed_images_total = (transformed_0 + transformed_1 + transformed_2).detach()
        transformed_images_total = transformed_images_total / transformed_images_total.max()
        transformed_images_total_pil = to_pil(transformed_images_total.squeeze(0))

        # Original image with the learned sector boundaries drawn on top
        original_image_pil = to_pil(image_tensor.squeeze(0))
        draw = ImageDraw.Draw(original_image_pil)
        boundary_color = (0, 255, 0)  # Green color for boundaries

        center_x = int(model.center[0] * original_image_pil.width)
        center_y = int(model.center[1] * original_image_pil.height)

        # Draw boundary rays. Their length is move_distance * width, matching how
        # the translation is scaled in translate_each_sector_to_center_soft.
        for angle_tmp in model.angle + torch.tensor([0.0, 2 * np.pi / 3, 4 * np.pi / 3]):
            x = center_x + model.move_distance.item() * np.cos(angle_tmp.item()) * original_image_pil.width
            y = center_y + model.move_distance.item() * np.sin(angle_tmp.item()) * original_image_pil.width
            draw.line([(center_x, center_y), (x, y)], fill=boundary_color, width=2)

        display(original_image_pil)  # Original image with center point and boundaries
        display(transformed_images_total_pil)

In [5]:
# Write each separated sector to disk; the matching step reads these files back.
for sector_pil, out_path in (
    (transformed_images_0_pil, "transformed_image_0.jpg"),
    (transformed_images_1_pil, "transformed_image_1.jpg"),
    (transformed_images_2_pil, "transformed_image_2.jpg"),
):
    sector_pil.save(out_path)

# 2. Find matching points

In [6]:
# Read two of the separated sectors back as grayscale images.
img2 = cv2.imread('transformed_image_2.jpg', cv2.IMREAD_GRAYSCALE)
img1 = cv2.imread('transformed_image_1.jpg', cv2.IMREAD_GRAYSCALE)
# cv2.imread reports a missing/unreadable file by returning None rather than raising.
if img1 is None:
    raise FileNotFoundError("could not read transformed_image_1.jpg")
if img2 is None:
    raise FileNotFoundError("could not read transformed_image_2.jpg")

# Find keypoints and binary descriptors with ORB from opencv
orb = cv2.ORB_create()
kp1, des1 = orb.detectAndCompute(img1, None)
kp2, des2 = orb.detectAndCompute(img2, None)
# detectAndCompute returns descriptors=None when no keypoints are found.
if des1 is None or des2 is None:
    raise RuntimeError("ORB found no keypoints in at least one of the images")

# Brute-force Hamming matching (ORB descriptors are binary). crossCheck keeps only
# mutual best matches.
bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
matches = bf.match(des1, des2)

In [None]:
def filter_horizontal_matches(matches, keypoints1, keypoints2, threshold=5):
    """Keep matches whose two endpoints lie on (nearly) the same image row.

    Args:
        matches: match objects exposing ``queryIdx`` (index into ``keypoints1``)
            and ``trainIdx`` (index into ``keypoints2``).
        keypoints1, keypoints2: sequences of keypoints exposing ``.pt == (x, y)``.
        threshold: a match is kept when ``|y1 - y2| < threshold`` (strict).

    Returns:
        List of ``(match, length)`` tuples in input order. ``length`` is the
        Euclidean distance between the two keypoint positions.
    """
    def _endpoints(m):
        return np.array(keypoints1[m.queryIdx].pt), np.array(keypoints2[m.trainIdx].pt)

    candidates = ((m, *_endpoints(m)) for m in matches)
    return [
        (m, np.linalg.norm(p2 - p1))
        for m, p1, p2 in candidates
        if abs(p1[1] - p2[1]) < threshold
    ]


# Filter horizontal matches and calculate their lengths
horizontal_matches_with_lengths = filter_horizontal_matches(matches, kp1, kp2)
horizontal_matches = [match for match, _ in horizontal_matches_with_lengths]
lengths = [length for _, length in horizontal_matches_with_lengths]

# drawMatches places img1 and img2 side by side and returns a BGR image.
img_matches = cv2.drawMatches(img1, kp1, img2, kp2, horizontal_matches, None, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)

# Label each match with its length at the midpoint of its line. img2 keypoints are
# shifted right by img1's width on the side-by-side canvas.
x_offset = np.array([img1.shape[1], 0])
for match, length in horizontal_matches_with_lengths:
    pt1 = np.array(kp1[match.queryIdx].pt)
    pt2 = np.array(kp2[match.trainIdx].pt) + x_offset
    mid_point = (pt1 + pt2) / 2
    # (0, 0, 255) is red in OpenCV's BGR order.
    cv2.putText(img_matches, f'{length:.2f}', tuple(int(v) for v in mid_point), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)

# Display the matches; matplotlib expects RGB, so convert from OpenCV's BGR.
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(cv2.cvtColor(img_matches, cv2.COLOR_BGR2RGB))
ax.set_title("Horizontal ORB matches (labels: match length in pixels)")
ax.axis("off")
plt.show()



# 3. Plot keypoints with relative depth

In [None]:
# Interactive view: hover a keypoint of image 1 to see a relative depth derived from
# the length of its horizontal match.
fig = go.Figure()

# Add the separated sector (image 1) as background
fig.add_trace(go.Image(z=np.asarray(transformed_images_1_pil)))

# Plot all keypoints in a single Scatter trace (one trace per point scales poorly).
keypoint_xy = [kp1[match.queryIdx].pt for match, _ in horizontal_matches_with_lengths]
# NOTE(review): (length - 1) * 5 is an ad-hoc mapping from match length to "relative
# depth"; the constants are not derived here, so treat the values as relative only.
depth_labels = [f'depth: {(length-1)*5:.2f}' for _, length in horizontal_matches_with_lengths]
fig.add_trace(go.Scatter(x=[x for x, _ in keypoint_xy],
                         y=[y for _, y in keypoint_xy],
                         mode='markers',
                         marker=dict(color='red', size=5),
                         hoverinfo='text',
                         text=depth_labels,
                         showlegend=False))

# Update layout
fig.update_layout(
    title="Keypoints with Match Relative Depth",
    hovermode='closest',
    showlegend=False,
    height=img1.shape[0] * 3 // 2
)

fig.show()
