# Masked Image Style Transfer with Optimal Transport Loss

In [163]:
# Run the external `style_transfer` command-line tool on three image files with `-s 512`.
# NOTE(review): presumably the positional arguments are the content, style and mask images --
# the captured log reports a content image, a style image and a mask shape -- confirm against
# the tool's own help text.
# NOTE(review): sheep_resized.jpg is written by a later cell of this notebook, so on a fresh
# kernel that cell has to run before this one.
!style_transfer sheep_resized.jpg starry_night.jpg sheep_mask_resized.jpg -s 512

Using devices: cuda:0
GPU 0 type: NVIDIA GeForce RTX 4060 Laptop GPU (compute 8.9)
GPU 0 RAM: 8188 MB
Loading model...
Processing content image (128x94)...
this is the self.content_layers [22]
target shape torch.Size([1, 512, 11, 16]) current_mask shape torch.Size([1, 3, 94, 128])
Processing style image (128x80)...
self.style_layers [1, 6, 11, 20, 29]
Size: 128x94, iteration: 1, loss: 9.34844                                       
Size: 128x94, iteration: 2, loss: 11.7869                                       
Size: 128x94, iteration: 3, loss: 10.7228                                       
Size: 128x94, iteration: 4, loss: 10.0344                                       
Size: 128x94, iteration: 5, loss: 9.54213                                       
Size: 128x94, iteration: 6, loss: 9.27094                                       
Size: 128x94, iteration: 7, loss: 9.01643                                       
Size: 128x94, iteration: 8, loss: 9.64683                                      

In [156]:
import IPython
import cv2
import numpy as np
from IPython.display import display, Image
from io import BytesIO
import PIL.Image
from PIL import Image
from IPython.display import Image as IPyImage, display, HTML

# Original Image and Mask

In [157]:
# Load the source photo. The context manager closes the underlying file handle
# as soon as the resized copy has been produced.
image_path = 'sheep.jpg'
with Image.open(image_path) as image:
    # Resize to 512x378 (width x height) using Lanczos resampling.
    resized_image = image.resize((512, 378), Image.LANCZOS)
resized_image.save("./sheep_resized.jpg", 'JPEG')

# Show the resized photo next to its mask.
# NOTE(review): sheep_mask_resized.jpg is not written by any visible cell of this
# notebook -- it is expected to already exist next to the notebook file.
images_html = (
    "<table><tr>"
    "                <td><img src='sheep_resized.jpg'></td>"
    "                <td><img src='sheep_mask_resized.jpg'></td>"
    "                </tr></table>"
)
display(HTML(images_html))


0,1
,


 # Style Transfer with Optimal Transport Loss

In [164]:
# Re-create sheep_resized.jpg from the source photo (same steps as the earlier
# "Original Image and Mask" cell). The context manager closes the file handle.
image_path = 'sheep.jpg'
with Image.open(image_path) as image:
    resized_image = image.resize((512, 378), Image.LANCZOS)
resized_image.save("sheep_resized.jpg", 'JPEG')

# Resize the style image to the same 512x378 size so the table cells line up.
# This ignores the style image's own aspect ratio; the copy is written for display.
style_image_path = 'starry_night.jpg'
with Image.open(style_image_path) as style_source:
    style_image = style_source.resize((512, 378), Image.LANCZOS)
style_image.save("resized_starry_night.jpg", 'JPEG')

# The third image (out.png) is shown as-is, without resizing. Opening it here
# surfaces a missing or unreadable file before the HTML table is rendered;
# load() decodes the pixels so `output_image` stays usable after the file is closed.
output_image_path = 'out.png'
with Image.open(output_image_path) as output_image:
    output_image.load()

# Display the three images side by side using HTML.
images_html = (
    "<table><tr>"
    "                <td><img src='sheep_resized.jpg'></td>"
    "                <td><img src='resized_starry_night.jpg'></td>"
    "                <td><img src='out.png'></td>"
    "                </tr></table>"
)
display(HTML(images_html))

0,1,2
,,


# MSE-Style Loss, Wasserstein-2 Loss, and Optimal Transport Loss

In [166]:
# Loss discussion: render the three result images in a single-row HTML table,
# in the order out_MSEStyleLoss.png, out_W2Loss.png, out.png.
loss_row_cells = (
    "                <td><img src='out_MSEStyleLoss.png'></td>",
    "                <td><img src='out_W2Loss.png'></td>",
    "                <td><img src='out.png'></td>",
)
images_html = "<table><tr>" + "".join(loss_row_cells) + "                </tr></table>"
display(HTML(images_html))



0,1,2
,,


# Naive Masked Style Transfer and Mask Strategies

In [153]:
# Composite out.png onto the resized photo inside the mask using two strategies:
# a hard cut-and-paste along the mask, and OpenCV's seamlessClone. The results
# are written to disk and shown side by side.
image_a = cv2.imread('sheep_resized.jpg')  # photo written by the resize cell (BGR)
image_b = cv2.imread('out.png')            # presumably the style-transfer result (BGR)
mask = cv2.imread('sheep_mask_resized.jpg', cv2.IMREAD_GRAYSCALE)

# cv2.imread returns None instead of raising when a file is missing or unreadable.
for path, loaded in (('sheep_resized.jpg', image_a),
                     ('out.png', image_b),
                     ('sheep_mask_resized.jpg', mask)):
    if loaded is None:
        raise FileNotFoundError(f"cv2.imread could not read {path!r}")

height, width = image_a.shape[:2]

# The masked operations below need every input to have the same height/width.
if image_b.shape[:2] != (height, width):
    image_b = cv2.resize(image_b, (width, height))
mask = cv2.resize(mask, (width, height))

# The mask comes from a JPEG and has just been resized with interpolation, so
# its edges contain intermediate grey values. For such a pixel both `mask` and
# `bitwise_not(mask)` are non-zero, both layers would be copied, and cv2.add
# would saturate their sum into a bright fringe. Binarise it first.
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
inverse_mask = cv2.bitwise_not(mask)

# Naive blend: stylised pixels inside the mask, original pixels outside it.
masked_area_b = cv2.bitwise_and(image_b, image_b, mask=mask)
remaining_area_a = cv2.bitwise_and(image_a, image_a, mask=inverse_mask)
result_image = cv2.add(remaining_area_a, masked_area_b)
result_image_pil = PIL.Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
# NOTE(review): the blended composite is saved as "sheep_resized_maskedOnly.jpg"
# while the masked cut-out of the original photo (below) is saved as
# "naive_blend.jpg". The two names look swapped relative to their contents;
# they are kept unchanged here -- confirm the intended naming.
result_image_pil.save("./sheep_resized_maskedOnly.jpg", 'JPEG')

# Original photo restricted to the masked region (everything else is black).
masked_area_a = cv2.bitwise_and(image_a, image_a, mask=mask)
result_image_pil = PIL.Image.fromarray(cv2.cvtColor(masked_area_a, cv2.COLOR_BGR2RGB))
result_image_pil.save("./naive_blend.jpg", 'JPEG')

# Seamless clone: centre the clone on the mask's bounding box so the cloned
# region stays at the location it occupies in the source image.
br = cv2.boundingRect(mask)  # (x, y, width, height)
centerOfBR = (br[0] + br[2] // 2, br[1] + br[3] // 2)
result_image = cv2.seamlessClone(image_b, image_a, mask, centerOfBR, cv2.NORMAL_CLONE)
result_image_pil = PIL.Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
result_image_pil.save("seamlessClone_naive_mask.jpg", 'JPEG')

# Show the three saved images side by side.
images_html = (
    "<table><tr>"
    "                <td><img src='naive_blend.jpg'></td>"
    "                <td><img src='sheep_resized_maskedOnly.jpg'></td>"
    "                <td><img src='seamlessClone_naive_mask.jpg'></td>"
    "                </tr></table>"
)
display(HTML(images_html))




0,1,2
,,


# Masked Style Transfer and Mask Strategies

In [154]:

# Composite out_maksedOnly.png onto the resized photo inside the mask, once
# with a hard cut-and-paste and once with OpenCV's seamlessClone, then show
# the input and both results side by side.
# NOTE(review): presumably out_maksedOnly.png is the output of the masked
# style-transfer run; the file name is kept exactly as spelled on disk.
image_a = cv2.imread('sheep_resized.jpg')
image_b = cv2.imread('out_maksedOnly.png')
mask = cv2.imread('sheep_mask_resized.jpg', cv2.IMREAD_GRAYSCALE)

# cv2.imread signals a missing or unreadable file by returning None.
if image_a is None:
    raise FileNotFoundError("cv2.imread could not read 'sheep_resized.jpg'")
if image_b is None:
    raise FileNotFoundError("cv2.imread could not read 'out_maksedOnly.png'")
if mask is None:
    raise FileNotFoundError("cv2.imread could not read 'sheep_mask_resized.jpg'")

target_size = (image_a.shape[1], image_a.shape[0])  # cv2.resize expects (width, height)

# Bring the stylised image and the mask onto the photo's pixel grid; the
# masked operations below require identical sizes.
if image_b.shape[:2] != image_a.shape[:2]:
    image_b = cv2.resize(image_b, target_size)
mask = cv2.resize(mask, target_size)

# JPEG compression and the interpolating resize leave grey values along the
# mask edge. Without binarising, those pixels are non-zero in both the mask
# and its inverse, so both layers are copied and cv2.add saturates their sum.
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
inverse_mask = cv2.bitwise_not(mask)

# Naive mask: stylised pixels inside the mask, original pixels outside it.
masked_area_b = cv2.bitwise_and(image_b, image_b, mask=mask)
remaining_area_a = cv2.bitwise_and(image_a, image_a, mask=inverse_mask)
result_image = cv2.add(remaining_area_a, masked_area_b)

result_image_pil = PIL.Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
result_image_pil.save("naive_blend_maskedOnly.jpg", 'JPEG')

# Seamless clone: the inputs are unchanged since they were loaded above, so
# they are reused rather than read from disk a second time. The clone is
# centred on the mask's bounding box to keep the region in place.
br = cv2.boundingRect(mask)  # (x, y, width, height)
centerOfBR = (br[0] + br[2] // 2, br[1] + br[3] // 2)
result_image = cv2.seamlessClone(image_b, image_a, mask, centerOfBR, cv2.NORMAL_CLONE)
result_image_pil = PIL.Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
result_image_pil.save("seamlessClone_maskedOnly.jpg", 'JPEG')

# Show the stylised input and the two composites side by side.
images_html = (
    "<table><tr>"
    "                <td><img src='out_maksedOnly.png'></td>"
    "                <td><img src='naive_blend_maskedOnly.jpg'></td>"
    "                <td><img src='seamlessClone_maskedOnly.jpg'></td>"
    "                </tr></table>"
)
display(HTML(images_html))




0,1,2
,,


# Reference
Optimal Transport Loss - https://www.youtube.com/watch?v=ZFYZFlY7lgI&t=610s