## Mandatory Imports from auxiliary libraries and custom implementation

In [None]:
import torch
from torch.autograd        import grad
from pykeops.torch         import Kernel

import json
import numpy as np
from scipy import misc
from scipy.ndimage.filters import gaussian_filter
from matplotlib import pyplot as plt
from time import time
import extract_coordinates_json as json_coord
import extract_svg_coordinates as svg_coord
from statsmodels import robust
from common.sinkhorn_balanced import sinkhorn_divergence

## Deciding whether to use a CUDA (GPU) or a regular (CPU) PyTorch FloatTensor, and defining the raw file paths

In [None]:
# Pick the tensor constructor once: the CUDA float type when a GPU is
# available, the CPU float type otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    tensor = torch.cuda.FloatTensor
else:
    tensor = torch.FloatTensor

# Switch matplotlib to interactive mode before anything is drawn.
plt.ion()
plt.show()


def s2v(x):
    """Wrap the single value `x` into a one-element tensor."""
    return tensor([x])


# Raw input files (absolute paths on the original author's machine).
svg_path_shape1 = r'/home/yash/Desktop/HaTran/Shape1_start_absolute.svg'
svg_path_shape2 = r'/home/yash/Desktop/HaTran/Shape2_start_absolute.svg'
svg_path_shape3 = r'/home/yash/Desktop/HaTran/Shape3_start_absolute.svg'

json_path = r'/home/yash/Desktop/HaTran/patient_data.json'

## Defining what kind of routines to run

In [None]:
# Registry of experiment settings, keyed by a descriptive name.
experiments = {}

# Sinkhorn runs: one entry per (cost exponent, blur scale, iteration count).
for p in (2,):  # exponent of the cost C(x,y) = |x-y|^p; only p = 2 is enabled
    for eps, eps_s in ((.01, "S"),):  # (blur value, short label used in the name)
        for nits in (2,):
            experiments.update({
                "sinkhorn_L{}_{}_{}its".format(p, eps_s, nits): {
                    "formula": "sinkhorn",
                    "p": p,
                    # eps is homogeneous to C(x,y), hence the power p.
                    "eps": eps ** p,
                    "nits": nits,
                    # Zero tolerance: run all iterations, no early stopping.
                    "tol": 0.,
                    "transport_plan": "heatmaps",
                },
            })

## Loading an image (png) as a PyTorch tensor

In [None]:
def LoadImage(fname):
    """Load an image file as an inverted, rescaled tensor.

    The file is read in grayscale, blurred, flipped along its first axis,
    divided by 255, and has its two axes swapped. The returned value is
    `1 - image`, wrapped with the module-level `tensor` constructor.

    Args:
        fname: path of the image file to read.

    Returns:
        A tensor built from the processed image.
    """
    grayscale = misc.imread(fname, flatten=True)

    # A Gaussian blur (sigma = 1) smooths the image; smoother inputs give
    # smoother gradients under automatic differentiation, which tends to
    # help convergence.
    blurred = gaussian_filter(grayscale, 1, mode='nearest')

    # Reverse the row order and rescale by 1/255
    # (presumably pixel values are in [0, 255] -- TODO confirm).
    rescaled = blurred[::-1, :] / 255.

    # Swap the two axes so the first tensor index follows the image columns.
    swapped = np.swapaxes(rescaled, 0, 1)

    return tensor(1 - swapped)

## Dataset and some Macros

In [None]:
# Name of the (source, target) image pair to work on, and the known pairs.
dataset = "shape"
datasets = {
    "shape": ("data/patient_input_0.png", "data/shape1.png"),
}

# Load both images of the selected pair, source first.
# Both measures are normalized later on, inside "sparse_distance_bmp".
source, target = (LoadImage(image_path) for image_path in datasets[dataset])

print(source.shape)

# Dividing the identity transform by the first image dimension rescales the
# pixel coordinates (to the unit square for a square image).
scale = source.shape[0]
affine = tensor([[1, 0, 0], [0, 1, 0]]) / scale

# Forwarded by calculate_score as the `info` flag of sparse_distance_bmp.
display = False

## Extracting a Point Cloud of an Image

In [None]:
def extract_point_cloud(I, affine):
    """Convert a bitmap into a weighted point cloud.

    Args:
        I: 2D or 3D tensor of pixel/voxel values.
        affine: affine transform; its leading D-by-D part is applied to the
            pixel indices and its column of index D is added as an offset.

    Returns:
        A tuple (ind, weights, points) where `ind` holds the indices of the
        entries of `I` above the threshold (one row per entry), `weights`
        is a column of the matching values scaled by
        affine[0, 0] * affine[1, 1], and `points` holds the transformed
        coordinates.

    Raises:
        NotImplementedError: if `I` is neither 2D nor 3D.
    """
    # Keep only the entries that are (strictly) above a small threshold.
    ind = (I > .001).nonzero()

    D = len(I.shape)
    if D not in (2, 3):
        raise NotImplementedError()

    # Values of the retained entries, gathered one index column per axis.
    weights = I[tuple(ind[:, axis] for axis in range(D))]

    # Lazy approximation of the determinant of the transform; it does not
    # matter once the measures are normalized.
    weights = weights * affine[0, 0] * affine[1, 1]

    # Change of coordinates: linear part, then offset.
    linear = affine[:D, :D]
    offset = affine[:D, D]
    points = ind.float().matmul(linear.t()) + offset

    return ind, weights.view(-1, 1), points

## Function to calculate the sparse distance 

In [None]:
def sparse_distance_bmp(params, A, B, affine_A, affine_B, normalize=True, info=False, action="measure"):
    """
    Takes as input two torch bitmaps (Tensors), seen as weighted point clouds.
    Returns a cost, its gradient encoded as a vector bitmap, and the heatmaps
    produced by the underlying routine.

    Args :
        - params : dict of settings. "formula" selects the routine (defaults
          to "sinkhorn"); the whole dict, plus a "heatmaps" entry set to
          `info`, is forwarded to the routine as keyword arguments.
          The caller's dict is not modified.
        - A and B : two torch bitmaps (Tensors) of dimension D.
        - affine_A and affine_B : affine transforms (Tensors); the leading
          D-by-D part and the column of index D are used by
          `extract_point_cloud`.
        - normalize : if True, both weight vectors are rescaled to sum to 1.
        - info : value forwarded to the routine as its "heatmaps" argument.
        - action : "measure" differentiates the cost wrt. the point positions
          only; "image" differentiates wrt. positions and weights and spreads
          the result over neighbouring pixels (2D only; for D == 3 the
          returned gradient stays zero). Any other value skips the gradient
          computation and returns a zero gradient.

    Returns :
        (cost, grad_A, heatmaps), where grad_A is a detached float tensor of
        shape A.shape + (D,).

    Raises :
        - KeyError if params["formula"] is not a known routine.
        - NotImplementedError in "measure" mode if D is neither 2 nor 3.
    """
    D = len(A.shape)  # dimension of the ambient space, =2 for slices or =3 for volumes

    ind_A, α_i, x_i = extract_point_cloud(A, affine_A)
    ind_B, β_j, y_j = extract_point_cloud(B, affine_B)

    if normalize:
        α_i = α_i / α_i.sum()
        β_j = β_j / β_j.sum()

    x_i.requires_grad = True
    if action == "image":
        α_i.requires_grad = True

    # Compute the distance between the *measures* A and B ------------------------------
    routines = {
        "sinkhorn": sinkhorn_divergence,  # implementation taken as-is from the authors
    }
    routine = routines[params.get("formula", "sinkhorn")]

    # Build the keyword arguments on a copy, so that the caller's settings
    # dict is not polluted with a "heatmaps" entry.
    routine_kwargs = dict(params, heatmaps=info)
    cost, heatmaps = routine(α_i, x_i, β_j, y_j, **routine_kwargs)

    if action == "image":
        grad_a, grad_x = grad(cost, [α_i, x_i])  # gradient wrt the voxels' positions and weights
    elif action == "measure":
        grad_x = grad(cost, [x_i])[0]  # gradient wrt the voxels' positions

    # Point cloud to bitmap (grad_x) ---------------------------------------------------
    # Zero-filled float bitmap with one D-vector per pixel, allocated directly
    # on the device that holds A.
    grad_A = torch.zeros(tuple(A.shape) + (D,), dtype=torch.float32, device=A.device)

    if action == "measure":
        if D == 2:
            grad_A[ind_A[:, 0], ind_A[:, 1], :] = grad_x[:, :]
        elif D == 3:
            grad_A[ind_A[:, 0], ind_A[:, 1], ind_A[:, 2], :] = grad_x[:, :]
        else:
            raise NotImplementedError()

    elif action == "image":
        if D == 2:
            rows, cols = ind_A[:, 0], ind_A[:, 1]
            # NOTE(review): the `+ 1` neighbours below index one past each
            # retained pixel; a retained pixel on the last row or column of A
            # would index out of range.

            # Spread the position gradient equally over the four neighbours.
            dim_0 = affine_A[0, 0]
            spread = .25 * dim_0 * grad_x
            grad_A[rows, cols, :] += spread
            grad_A[rows + 1, cols, :] += spread
            grad_A[rows, cols + 1, :] += spread
            grad_A[rows + 1, cols + 1, :] += spread

            # The weights are stored as an (N, 1) column while the slices
            # updated below are 1D: flatten first, since in-place updates
            # cannot broadcast (N,) against (N, 1).
            grad_a = (grad_a * α_i).reshape(-1)
            half = .5 * grad_a

            # Weight gradient, as a finite difference along the first axis...
            grad_A[rows, cols, 0] -= half
            grad_A[rows + 1, cols, 0] += half
            grad_A[rows, cols + 1, 0] -= half
            grad_A[rows + 1, cols + 1, 0] += half

            # ... and along the second axis.
            grad_A[rows, cols, 1] -= half
            grad_A[rows, cols + 1, 1] += half
            grad_A[rows + 1, cols, 1] -= half
            grad_A[rows + 1, cols + 1, 1] += half

    # N.B.: we return "PLUS gradient", i.e. "MINUS a descent direction".
    return cost, grad_A.detach(), heatmaps

In [None]:
def calculate_score(name, params):
    """Return the cost between the module-level `source` and `target` images.

    Calls `sparse_distance_bmp` with the module-level `affine` transform for
    both images, normalized measures, and `display` as the `info` flag.

    Args:
        name: label of the experiment; accepted for the caller's convenience
            but not used in the computation.
        params: settings dict forwarded to `sparse_distance_bmp`.

    Returns:
        The cost as a Python float, rounded to 6 decimal places.
    """
    # Only the cost is needed here; the gradient and heatmaps are discarded.
    cost, _grad_src, _heatmaps = sparse_distance_bmp(
        params, source, target,
        affine, affine,
        normalize=True, info=display,
    )

    return float("{:.6f}".format(cost.item()))

In [None]:
# Score every registered experiment in turn. `cost` is overwritten on each
# pass, so after the loop it holds the score of the last experiment only
# (and stays undefined if `experiments` is empty).
for name, params in experiments.items():
    cost = calculate_score(name=name, params=params)


In [None]:
# Notebook cell output: show the score left in `cost` by the experiment loop.
cost