In [5]:
from typing import Callable, Tuple
from finite_distributions.FiniteDistribution import FiniteDistribution
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sinkhorn.SinkhornRunner as SinkhornRunner
import sinkhorn.SinkhornKernels as skern
import visualizer.joint_distribution_visualizer as jdv
import pandas as pd

from core.require import require

from PIL import Image, ImageDraw

In [6]:
# Load both input images as single-channel grayscale (mode "L").
# The checks below require a 2-D array, so a multi-channel mode such as "RGB"
# would not pass them without further changes.
# Alternative input pairs that were used previously:
# img_1 = Image.open("images/white_circle_on_black.png").convert("L")
# img_2 = Image.open("images/white_square_on_black.png").convert("L")
# img_1 = Image.open("images/white_square_on_black.png").convert("L")
# img_2 = Image.open("images/white_circle_on_black.png").convert("L")
img_1 = Image.open("images/black_square_on_white.png").convert("L")
img_2 = Image.open("images/black_circle_on_white.png").convert("L")

# Convert to NumPy arrays and rescale the pixel values by 1/255.
img_1_array = np.array(img_1)/255
img_2_array = np.array(img_2)/255

# Both images must have the same shape, be 2-D, and be square.
require(img_1_array.shape == img_2_array.shape)
require(len(img_1_array.shape) == 2)

rows = img_1_array.shape[0]
cols = img_1_array.shape[1]
require(rows == cols)

# Every pixel coordinate as an (x, y) pair, where x indexes rows and y columns.
keys = [(x, y) for x in range(rows) for y in range(cols)]

# Create distributions: each pixel's probability is its intensity divided by
# the image's total intensity, so brighter pixels carry more mass.
__sum_1 = img_1_array.sum()
# Fail fast on an all-black image; dividing by a zero total would otherwise
# fill the distribution with NaN.
require(float(__sum_1) > 0)
# Flat-index variant of the keys (disabled):
# img_1_distribution = FiniteDistribution({rows * x + y: img_1_array[x][y]/__sum_1 for (x, y) in keys})
img_1_distribution = FiniteDistribution({(x, y): img_1_array[x][y]/__sum_1 for (x, y) in keys})
# Same distribution restricted to the pixels with strictly positive mass.
img_1_distribution_reduced = FiniteDistribution({x: v for (x, v) in img_1_distribution.elementMapping.items() if v > 0})

__sum_2 = img_2_array.sum()
require(float(__sum_2) > 0)
# Flat-index variant of the keys (disabled):
# img_2_distribution = FiniteDistribution({rows * x + y: img_2_array[x][y]/__sum_2 for (x, y) in keys})
img_2_distribution = FiniteDistribution({(x, y): img_2_array[x][y]/__sum_2 for (x, y) in keys})
img_2_distribution_reduced = FiniteDistribution({x: v for (x, v) in img_2_distribution.elementMapping.items() if v > 0})


In [7]:
def c(x, y):
    """Return the squared Euclidean distance between two points.

    Only the first two components of ``x`` and ``y`` are read.
    """
    dx = x[0] - y[0]
    dy = x[1] - y[1]
    return dx**2 + dy**2

In [8]:
# Quadratically regularized runner (p-norm with p = 2), built on the cost `c`.
sinkhorn_runner = skern.get_quadratically_regularized_runner(c, use_parallelization=True)
# Entropic alternative (disabled):
# sinkhorn_runner_entropic = skern.get_entropically_regularized_runner(c)

# Regularization strengths to run; other values tried: 0.5, 1.0, 5.0, 10.0, 100.0
epsilons = [0.001]
delta = 0.05  # lower delta for entropic

# Iterate over the epsilons from last to first.
for epsilon in epsilons[::-1]:
    print("Running quadratic.")
    pi, f, g, inner_p, outer_p = sinkhorn_runner.run_sinkhorn(
        img_1_distribution,
        img_2_distribution,
        epsilon,
        delta,
        dual_potential_precision_mult=0.5,
        printInfo=True,
    )
    print(f"Ran quadratic. Took {outer_p} outer iterations.")

    # Entropic run (disabled):
    # print(f"Running entropic for epsilon = {epsilon}.")
    # pi, f, g, inner_e, outer_e = sinkhorn_runner_entropic.run_sinkhorn(img_1_distribution_reduced, img_2_distribution_reduced, epsilon, delta, printInfo=True)
    # print(f"Ran entropic. Took {outer_e} outer iterations.")

    # Halfway interpolation: every coupled pair of pixels in `pi` deposits its
    # mass at the (truncated) midpoint of the two pixel coordinates.
    pi_element_mapping = pi.elementMapping
    all_interpolations = [
        {"x": int(0.5 * row_a + 0.5 * row_b), "y": int(0.5 * col_a + 0.5 * col_b), "p": mass}
        for ((row_a, col_a), (row_b, col_b)), mass in pi_element_mapping.items()
    ]
    # Add a zero-mass record for every key of the first distribution so that
    # each of those pixels appears in the aggregated table.
    all_interpolations.extend(
        {"x": int(row), "y": int(col), "p": 0.0}
        for (row, col) in img_1_distribution.get_keys()
    )

    # Sum the mass per pixel, then rescale so the heaviest pixel becomes 255.
    pi_pandas = pd.DataFrame(all_interpolations)
    aggregated = pi_pandas.groupby(["x", "y"]).sum()
    max_weight = aggregated["p"].max()
    aggregated["p"] = aggregated["p"] * 255 / max_weight

    # Pivot the (x, y)-indexed table into a 2-D array: x along rows, y along columns.
    composite_img_array = aggregated.unstack().values

    composite_img = Image.fromarray(composite_img_array.astype(np.uint8))
    composite_img.show()
    composite_img.save(f"images/black_composite_img_p_epsilon_{round(epsilon, 4)}.png")
    # Other output names used previously (disabled):
    # composite_img.save(f"images/TEST_black_composite_img_p_entropic_epsilon_{round(epsilon, 3)}.png")
    # composite_img.save(f"images/white_composite_img_p_entropic_epsilon_{round(epsilon, 3)}.png")
    # composite_img.save(f"images/white_composite_img_p_epsilon_{round(epsilon, 3)}.png")

Running quadratic.
Prior outer iterations: 1. inner iterations: 0.
Iterations for g: 4403200.0
Iterations for f: 4403200.0
Error: 0.13018735894320327
outer iterations: 1. inner iterations: 8799977.4720001. Error: 0.13018735894320327
Prior outer iterations: 2. inner iterations: 8799977.4720001.
Iterations for g: 4403153.973810203
Iterations for f: 4403153.973810203
Error: 0.1292627855146531
outer iterations: 2. inner iterations: 17599954.944000117. Error: 0.1292627855146531
Prior outer iterations: 3. inner iterations: 17599954.944000117.
Iterations for g: 4403107.947620392
Iterations for f: 4403107.947620392
Error: 0.12823076273230527
outer iterations: 3. inner iterations: 26399932.416000217. Error: 0.12823076273230527
Prior outer iterations: 4. inner iterations: 26399932.416000217.
Iterations for g: 4403061.921430681
Iterations for f: 4403061.921430681
Error: 0.12724794513523657
outer iterations: 4. inner iterations: 35199909.888000324. Error: 0.12724794513523657
Prior outer iterations

In [5]:
# Disabled helper cell: renders a test shape with matplotlib and saves it as a PNG.
# NOTE(review): no figure or axes background colour is set below, so matplotlib's
# default applies -- confirm the saved image really has the black background that
# the file name "white_circle_on_black_large.png" implies before re-enabling.
# import matplotlib.patches as patches

# # Set up figure size to match 64x64 pixels
# dpi = 100
# size = 64
# figsize = (size / dpi, size / dpi)

# fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

# # Add white circle
# circle = patches.Circle((3*size/8, 3*size/8), size/8, facecolor='white')
# ax.add_patch(circle)

# # # Add black square (alternative shape)
# # square = patches.Rectangle((size/2, size/2), size/4, size/4, facecolor='black')
# # ax.add_patch(square)

# # Formatting: fix the axes to the size x size extent, keep the aspect ratio, hide the axes
# ax.set_xlim(0, size)
# ax.set_ylim(0, size)
# ax.set_aspect('equal')
# ax.axis('off')

# # Save to file without margins or padding
# plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
# # plt.savefig('images/black_square_on_white_large.png', dpi=dpi, bbox_inches='tight', pad_inches=0)
# plt.savefig('images/white_circle_on_black_large.png', dpi=dpi, bbox_inches='tight', pad_inches=0)
# plt.close()

In [6]:
# Disabled helper cell: draws a filled white circle on a black grayscale canvas
# with Pillow and saves it to images/white_circle_on_black.png.
# import numpy as np
# from PIL import Image, ImageDraw

# # Create a black background image (mode 'L' = grayscale, 0 = black)
# img_size = 64
# img = Image.new('L', (img_size, img_size), color=0)

# # Draw a white filled circle
# # (uses img_size: the name `size` is not defined anywhere in this cell)
# draw = ImageDraw.Draw(img)
# circle_radius = img_size/8
# center = (3*img_size//8, 5*img_size//8)
# bbox = [
#     center[0] - circle_radius,
#     center[1] - circle_radius,
#     center[0] + circle_radius,
#     center[1] + circle_radius
# ]
# draw.ellipse(bbox, fill=255)  # 255 = white

# # Save the image
# img.save("images/white_circle_on_black.png")
# img.show()

In [7]:
# Disabled helper cell: draws a filled white square on a black grayscale canvas
# with Pillow and saves it to images/white_square_on_black.png.
# import numpy as np
# from PIL import Image, ImageDraw

# # Create a black background image (mode 'L' = grayscale, 0 = black)
# img_size = 64
# img = Image.new('L', (img_size, img_size), color=0)

# # Draw a white filled square
# # (uses img_size: the name `size` is not defined anywhere in this cell)
# draw = ImageDraw.Draw(img)
# square_size = img_size/4
# top_left = (img_size // 2, img_size//4)
# bottom_right = (top_left[0] + square_size, top_left[1] + square_size)
# draw.rectangle([top_left, bottom_right], fill=255)  # 255 = white

# # Save the image
# img.save("images/white_square_on_black.png")
# img.show()