In [28]:
from typing import Callable, Tuple
from finite_distributions.FiniteDistribution import FiniteDistribution
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sinkhorn.SinkhornRunner as SinkhornRunner
import sinkhorn.SinkhornKernels as skern
import visualizer.joint_distribution_visualizer as jdv
import pandas as pd

from core.require import require

from PIL import Image, ImageDraw

In [29]:
# Load both images as 8-bit grayscale ("L"); use "RGB" instead if colour is needed.
# Other inputs used previously: images/white_circle_on_black.png, images/white_square_on_black.png
def load_grayscale(path: str) -> Image.Image:
    """Open the image at `path` and return an independent grayscale ("L") copy.

    The `with` block closes the underlying file handle; `convert` returns a new image
    that does not depend on the opened file.
    """
    with Image.open(path) as opened:
        return opened.convert("L")


img_1 = load_grayscale("images/black_square_on_white_large.png")
img_2 = load_grayscale("images/black_circle_on_white_large.png")

# Scale 8-bit intensities into [0, 1].
img_1_array = np.array(img_1) / 255
img_2_array = np.array(img_2) / 255

# Both images must be single-channel, square, and the same size.
require(img_1_array.shape == img_2_array.shape)
require(len(img_1_array.shape) == 2)

rows, cols = img_1_array.shape
require(rows == cols)

# Pixel coordinates (row, col) in row-major order; these are the distribution's support.
keys = [(x, y) for x in range(rows) for y in range(cols)]


def image_to_distribution(image_array: np.ndarray) -> FiniteDistribution:
    """Turn a 2-D intensity array into a FiniteDistribution over (row, col) pixel keys.

    Each pixel's weight is its intensity divided by the image's total intensity, so
    brighter pixels carry more mass.

    NOTE(review): with the black-on-white inputs above, the mass therefore sits on the
    white background rather than on the black shape. Pass `1 - image_array` if the shape
    itself should carry the mass — confirm which is intended.

    Args:
        image_array: 2-D array of non-negative intensities.

    Returns:
        FiniteDistribution mapping (row, col) -> normalized intensity.
    """
    total = image_array.sum()
    # An all-black image has zero total mass; normalizing it would produce NaNs.
    require(total > 0)
    n_rows, n_cols = image_array.shape
    return FiniteDistribution(
        {(x, y): image_array[x, y] / total for x in range(n_rows) for y in range(n_cols)}
    )


img_1_distribution = image_to_distribution(img_1_array)
img_2_distribution = image_to_distribution(img_2_array)

In [30]:
def c(x, y):
    """Squared Euclidean distance between two 2-D points (pixel coordinates).

    Used as the transport cost: moving mass from pixel x to pixel y costs
    (x[0] - y[0])**2 + (x[1] - y[1])**2.

    A named `def` (rather than an assigned lambda, PEP 8 E731) also pickles by name,
    which matters if the runner ships the cost to worker processes when
    use_parallelization=True — NOTE(review): confirm the runner uses multiprocessing.
    """
    return (x[0] - y[0]) ** 2 + (x[1] - y[1]) ** 2

In [38]:
# Regularization exponent (2 => quadratic). Informational only: it is not passed to the
# runner below, whose quadratic regularization is fixed by the factory it comes from.
p = 2.

# Passed straight to run_sinkhorn; presumably the regularization strength and the
# stopping tolerance respectively — confirm against SinkhornRunner.
epsilon = 0.1
delta = 0.01

# An entropic variant is available via skern.get_entropically_regularized_runner(c).
sinkhorn_runner = skern.get_quadratically_regularized_runner(c, use_parallelization=True)

print("Running quadratic.")
pi_p, f_p, g_p, inner_p, outer_p = sinkhorn_runner.run_sinkhorn(
    img_1_distribution,
    img_2_distribution,
    epsilon,
    delta,
    dual_potential_precision_mult=0.5,
    printInfo=True,
)
print(f"Ran quadratic. Took {outer_p} outer iterations.")

Running quadratic.
Prior outer iterations: 1. inner iterations: 0.
Iterations for g: 4403200.0
Iterations for f: 4403200.0
Error: 0.1283347591299048
outer iterations: 1. inner iterations: 8742320.355556028. Error: 0.1283347591299048
Prior outer iterations: 2. inner iterations: 8742320.355556028.
Iterations for g: 4402688.597890975
Iterations for f: 4402688.597890975
Error: 0.11459972480545293
outer iterations: 2. inner iterations: 17484640.711111166. Error: 0.11459972480545293
Prior outer iterations: 3. inner iterations: 17484640.711111166.
Iterations for g: 4402177.1957818195
Iterations for f: 4402177.1957818195
Error: 0.10691557839706421
outer iterations: 3. inner iterations: 26226960.403892793. Error: 0.10691557839706421
Prior outer iterations: 4. inner iterations: 26226960.403892793.
Iterations for g: 4401678.526346107
Iterations for f: 4401678.526346107
Error: 0.10041647484152227
outer iterations: 4. inner iterations: 34969242.68633338. Error: 0.10041647484152227
Prior outer itera

In [39]:
# Geodesic (displacement) interpolation between the two images: every coupled pixel pair
# ((x0, y0), (x1, y1)) in the transport plan moves its mass to the point a fraction
# `t_interp` of the way from (x0, y0) to (x1, y1), truncated to an integer pixel.
# Assumes plan keys are (source pixel, target pixel) pairs, so t_interp = 0 reproduces
# the first image and 1 the second; 0.5 gives the midpoint image.
t_interp = 0.5
pi_element_mapping = pi_p.elementMapping
pi_pandas = pd.DataFrame(
    [
        {
            "x": int((1 - t_interp) * x0 + t_interp * x1),
            "y": int((1 - t_interp) * y0 + t_interp * y1),
            "p": mass,
        }
        for ((x0, y0), (x1, y1)), mass in pi_element_mapping.items()
    ]
)
# Total mass landing on each pixel, rescaled so the brightest pixel is 255.
aggregated = pi_pandas.groupby(["x", "y"]).sum()
max_weight = aggregated["p"].max()
aggregated["p"] = 255 * aggregated["p"] / max_weight
# Lay the pixel masses out on the full rows x cols grid. A bare unstack() leaves NaN for
# pixels that receive no mass (sparse plans) and drops rows/columns that are entirely
# empty, which would shrink or misalign the image; fill those with 0 instead.
composite_img_array = (
    aggregated["p"]
    .unstack(fill_value=0.0)
    .reindex(index=range(rows), columns=range(cols), fill_value=0.0)
    .to_numpy()
)

In [40]:
# Convert the float intensities (scaled upstream to peak at 255) to 8-bit, preview them,
# and save with epsilon in the filename. PIL's Image.show() opens an external viewer.
output_path = f"images/composite_img_p_epsilon_{round(epsilon, 3)}.png"
composite_img = Image.fromarray(composite_img_array.astype(np.uint8))
composite_img.show()
composite_img.save(output_path)

In [42]:
# Renormalize the composite so its entries sum to 1 (a per-pixel probability mass).
composite_img_array / composite_img_array.sum()

array([[1.73001539e-04, 2.97050239e-04, 2.94175166e-04, ...,
        3.10396673e-04, 2.49786387e-04, 6.30363785e-05],
       [2.97127874e-04, 4.55717528e-04, 3.85241804e-04, ...,
        4.32456942e-04, 4.09585473e-04, 1.10940816e-04],
       [2.93999164e-04, 3.85500779e-04, 2.93979597e-04, ...,
        3.36425479e-04, 3.75125572e-04, 1.15001376e-04],
       ...,
       [3.14114553e-04, 4.40482471e-04, 3.45117129e-04, ...,
        3.91745582e-04, 4.13862489e-04, 1.19243275e-04],
       [2.51861531e-04, 4.14407620e-04, 3.81254194e-04, ...,
        4.14050829e-04, 3.56273646e-04, 9.24952273e-05],
       [6.34631003e-05, 1.11914889e-04, 1.16518823e-04, ...,
        1.19356262e-04, 9.25138171e-05, 2.29256498e-05]])

In [34]:
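# Disabled helper: regenerates the (intended 64x64) shape-on-white test images loaded
# above. Uncomment, choose the shape and output path, then run to recreate them.
# import matplotlib.patches as patches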

# # Set up figure size to match 64x64 pixels
# dpi = 100
# size = 64
# figsize = (size / dpi, size / dpi)

# fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

# # Add black circle
# circle = patches.Circle((3*size/8, 3*size/8), size/8, facecolor='black')
# ax.add_patch(circle)

# # # Add black square
# # square = patches.Rectangle((size/2, size/2), size/4, size/4, facecolor='black')
# # ax.add_patch(square)

# # Formatting
# ax.set_xlim(0, size)
# ax.set_ylim(0, size)
# ax.set_aspect('equal')
# ax.axis('off')

# # Save to file
# plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
# # plt.savefig('images/black_square_on_white_large.png', dpi=dpi, bbox_inches='tight', pad_inches=0)
# plt.savefig('images/black_circle_on_white_large.png', dpi=dpi, bbox_inches='tight', pad_inches=0)
# plt.close()