# In [1]:
import numpy as np
import kornia
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

def _energy_map(image_array):
    """Return the (H, W) energy map of an H x W or H x W x C image array.

    Energy is the Sobel gradient magnitude sqrt(dx**2 + dy**2), computed per
    channel and summed over channels.
    """
    # (1, C, H, W) float tensor; ToTensor scales uint8 input to [0, 1].
    image_tensor = transforms.ToTensor()(image_array).unsqueeze(0).float()
    # Per the kornia docs, SpatialGradient returns (B, C, 2, H, W); axis 2 is (dx, dy).
    spatial_gradient = kornia.filters.SpatialGradient(mode='sobel')(image_tensor)
    magnitude = torch.sqrt(spatial_gradient[:, :, 0] ** 2 + spatial_gradient[:, :, 1] ** 2)
    return magnitude.sum(dim=1)[0]


def _cumulative_energy(energy):
    """Return the DP table M where M[r, c] is the cheapest cost of a seam
    (column changes by at most one per row) running from row 0 to pixel (r, c)."""
    scores = energy.clone()
    inf = scores.new_full((1,), float('inf'))
    for row in range(1, scores.shape[0]):
        prev = scores[row - 1]
        # Shift the previous row so each column sees its up-left / up / up-right
        # neighbours; out-of-image neighbours are +inf so they are never chosen.
        up_left = torch.cat([inf, prev[:-1]])
        up_right = torch.cat([prev[1:], inf])
        scores[row] += torch.minimum(torch.minimum(up_left, prev), up_right)
    return scores


def _find_vertical_seam(scores):
    """Backtrack through *scores* and return the seam's column index for every row, top to bottom."""
    height, width = scores.shape
    seam = [0] * height
    col = int(torch.argmin(scores[-1]).item())
    seam[-1] = col
    for row in range(height - 2, -1, -1):
        # Only the (up to) three pixels directly above are legal predecessors.
        start = max(col - 1, 0)
        stop = min(col + 2, width)
        col = start + int(torch.argmin(scores[row, start:stop]).item())
        seam[row] = col
    return seam


def _remove_vertical_seam(image_array, seam):
    """Return a copy of *image_array* with one pixel per row (given by *seam*) removed."""
    height, width = image_array.shape[:2]
    keep = np.ones((height, width), dtype=bool)
    keep[np.arange(height), seam] = False
    # Preserve any trailing channel axis (grayscale, RGB, RGBA, ...).
    new_shape = (height, width - 1) + image_array.shape[2:]
    return image_array[keep].reshape(new_shape)


def CarvingHelper(input_image, num_seams_to_remove):
    """Remove ``num_seams_to_remove`` minimum-energy vertical seams from an image.

    Energy and the seam DP table are recomputed after every removal, so each
    seam is optimal for the image as it is at that moment.

    Args:
        input_image: the image to carve. ``MySeamCarving`` passes a PIL image;
            it is converted with ``np.array`` to an H x W (x C) array.
        num_seams_to_remove: number of columns to remove. Values <= 0 are a no-op.

    Returns:
        ``input_image`` itself when ``num_seams_to_remove <= 0``; otherwise a new
        PIL image whose width is reduced by ``num_seams_to_remove``.

    Raises:
        ValueError: if ``num_seams_to_remove`` is not smaller than the image width.
    """
    if num_seams_to_remove <= 0:
        return input_image

    carved_image = np.array(input_image)
    image_width = carved_image.shape[1]
    if num_seams_to_remove >= image_width:
        raise ValueError(
            f"cannot remove {num_seams_to_remove} seams from an image of width {image_width}"
        )

    for _ in range(num_seams_to_remove):
        energy = _energy_map(carved_image)
        scores = _cumulative_energy(energy)
        seam = _find_vertical_seam(scores)
        carved_image = _remove_vertical_seam(carved_image, seam)

    return Image.fromarray(carved_image)

def MySeamCarving(img, target_width, target_height):
    """Shrink ``img`` to ``target_width`` x ``target_height`` via seam carving.

    Vertical seams are removed first. The image is then rotated so its rows
    become columns, carved again to remove horizontal seams, and rotated back.
    A target larger than the image in either dimension leaves that dimension
    unchanged, because CarvingHelper is a no-op for a non-positive seam count.
    """
    seams_x = img.width - target_width
    seams_y = img.height - target_height

    narrowed = CarvingHelper(img, seams_x)

    # Carve rows by rotating them into columns, then undo the rotation.
    sideways = narrowed.transpose(method=Image.Transpose.ROTATE_270)
    sideways_carved = CarvingHelper(sideways, seams_y)
    return sideways_carved.transpose(method=Image.Transpose.ROTATE_90)