# TensorFlow's `sparse_image_warp` in Pytorch

In [1]:
# first call `%store spectro` in the other notebook to save spectro value here
%store -r spectro

In [2]:
import torch
import numpy as np

In [3]:
# source_control_point_locations: `[batch, num_control_points, 2]` float
#   `Tensor`
# dest_control_point_locations: `[batch, num_control_points, 2]` float
#   `Tensor`

In [4]:
num_dimensions = 2

In [175]:
# paper says 'a random point' so we'll use 1 num_control_point
src_pts, dest_pts = torch.tensor([[[1,3]]]), torch.tensor([[[1,6]]])

In [165]:
spectro.shape

torch.Size([1, 128, 1718])

In [166]:
def sparse_image_warp(img_tensor:torch.Tensor,
                      source_control_point_locations:torch.Tensor,
                      dest_control_point_locations:torch.Tensor,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundaries_points=0):
    control_point_flows = (dest_control_point_locations - source_control_point_locations)
    
#     clamp_boundaries = num_boundary_points > 0
#     boundary_points_per_edge = num_boundary_points - 1
    batch_size, image_height, image_width = img_tensor.shape
    grid_locations = get_grid_locations(image_height, image_width)
    flattened_grid_locations = flatten_grid_locations(grid_locations, image_height, image_width)

#     flattened_grid_locations = constant_op.constant(
#         _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)

#     if clamp_boundaries:
#       (dest_control_point_locations,
#        control_point_flows) = _add_zero_flow_controls_at_boundary(
#            dest_control_point_locations, control_point_flows, image_height,
#            image_width, boundary_points_per_edge)

    flattened_flows = interpolate_spline(
        dest_control_point_locations,
        control_point_flows,
        flattened_grid_locations,
        interpolation_order,
        regularization_weight)
    # CONTINUE FROM HERE AFTER  INTERPOLATE FUNC IS FINISHED
    return flattened_flows

    dense_flows = create_dense_flows(flattened_flows, batch_size, image_height, image_width)

    warped_image = dense_image_warp(image, dense_flows)

    return warped_image, dense_flows

In [167]:
def get_grid_locations(image_height, image_width):
    """Wrapper for np.meshgrid."""

    y_range = np.linspace(0, image_height - 1, image_height)
    x_range = np.linspace(0, image_width - 1, image_width)
    y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij')
    return np.stack((y_grid, x_grid), -1)

In [168]:
def flatten_grid_locations(grid_locations, image_height, image_width):
    return np.reshape(grid_locations, [image_height * image_width, 2])

In [169]:
def create_dense_flows(flattened_flows, batch_size, image_height, image_width):
    # possibly .view
    return torch.reshape(flattened_flows, [batch_size, image_height, image_width, 2])

## Interpolate Spline

In [170]:
def interpolate_spline(train_points:torch.Tensor, train_values:torch.Tensor, query_points:torch.Tensor, order, regularization_weight=0.0,):
    # First, fit the spline to the observed data.
    w, v = solve_interpolation(train_points, train_values, order, regularization_weight)
    # Then, evaluate the spline at the query locations.
    query_values = apply_interpolation(query_points, train_points, w, v, order)

    return query_values

In [222]:
sparse_image_warp(spectro, src_pts, dest_pts)

Matrix A tensor([[[-0.]]]) torch.Size([1, 1, 1])
Matrix B tensor([[[1., 6., 1.]]]) torch.Size([1, 1, 3])
Left Block tensor([[[-0.],
         [1.],
         [6.],
         [1.]]]) torch.Size([1, 4, 1])
Num_B_Cols torch.Size([1, 1, 3])
Right Block tensor([[[ 1.0000e+00,  6.0000e+00,  1.0000e+00],
         [-1.7422e-11,  1.8778e-10, -7.2305e-11],
         [-1.3256e-10,  2.6914e-11, -1.7892e-10],
         [-8.3804e-12,  9.6467e-11, -2.5977e-11]]]) torch.Size([1, 4, 3])
LHS tensor([[[-0.0000e+00,  1.0000e+00,  6.0000e+00,  1.0000e+00],
         [ 1.0000e+00, -1.7422e-11,  1.8778e-10, -7.2305e-11],
         [ 6.0000e+00, -1.3256e-10,  2.6914e-11, -1.7892e-10],
         [ 1.0000e+00, -8.3804e-12,  9.6467e-11, -2.5977e-11]]]) torch.Size([1, 4, 4])
RHS tensor([[[0., 3.],
         [0., 0.],
         [0., 0.],
         [0., 0.]]]) torch.Size([1, 4, 2])
tensor([[[ 0.0000e+00, -1.1849e-10]]])
tensor([[[  0.0000, -12.8733],
         [  0.0000,   1.6762],
         [ -0.0000,   5.8163]]])


TypeError: apply_interpolation() takes 0 positional arguments but 5 were given

In [223]:
def solve_interpolation(train_points, train_values, order, regularization_weight):
    b, n, d = train_points.shape
    k = train_values.shape[-1]

    # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
    # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
    # To account for python style guidelines we use
    # matrix_a for A and matrix_b for B.
    
    c = train_points
    f = train_values.float()
    
    matrix_a = phi(cross_squared_distance_matrix(c,c), order).unsqueeze(0)  # [b, n, n]
    print('Matrix A', matrix_a, matrix_a.shape)
#     if regularization_weight > 0:
#         batch_identity_matrix = array_ops.expand_dims(
#           linalg_ops.eye(n, dtype=c.dtype), 0)
#         matrix_a += regularization_weight * batch_identity_matrix

    # Append ones to the feature values for the bias term in the linear model.
    ones = torch.ones(1, dtype=src_pts.dtype).view([-1, 1, 1])
    matrix_b = torch.cat((c, ones), 2).float()  # [b, n, d + 1]
    print('Matrix B', matrix_b, matrix_b.shape)

    # [b, n + d + 1, n]
    left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)
    print('Left Block', left_block, left_block.shape)

    num_b_cols = matrix_b.shape[2]  # d + 1
    print('Num_B_Cols', matrix_b.shape)
    #     lhs_zeros = torch.zeros((b, num_b_cols, num_b_cols), dtype=train_points.dtype).float()

    # In Tensorflow, zeros are used here. Pytorch gesv fails with zeros for some reason we don't understand.
    # So instead we use very tiny randn values (variance of one, zero mean) on one side of our multiplication.
    lhs_zeros = torch.randn((b, num_b_cols, num_b_cols)) / 1e10
    right_block = torch.cat((matrix_b, lhs_zeros),
                                   1)  # [b, n + d + 1, d + 1]
    print('Right Block', right_block, right_block.shape)
    lhs = torch.cat((left_block, right_block),
                           2)  # [b, n + d + 1, n + d + 1]
    print('LHS', lhs, lhs.shape)

    rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype).float()
    rhs = torch.cat((f, rhs_zeros), 1)  # [b, n + d + 1, k]
    print('RHS', rhs, rhs.shape)

    # Then, solve the linear system and unpack the results.
    X, LU = torch.gesv(rhs, lhs)
    w = X[:, :n, :]
    v = X[:, n:, :]

    return w, v

In [173]:
def cross_squared_distance_matrix(x, y):
    """Pairwise squared distance between two (batch) matrices' rows (2nd dim).
        Computes the pairwise distances between rows of x and rows of y
        Args:
        x: [batch_size, n, d] float `Tensor`
        y: [batch_size, m, d] float `Tensor`
        Returns:
        squared_dists: [batch_size, n, m] float `Tensor`, where
        squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
    """
    x_norm_squared = torch.sum(torch.mul(x, x))
    y_norm_squared = torch.sum(torch.mul(y, y))

    x_y_transpose = torch.matmul(x.squeeze(0), y.squeeze(0).transpose(0,1))
    
    # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj
    squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared

    return squared_dists.float()

In [153]:
# for [1,3] a [1,6]
cross_squared_distance_matrix(src_pts, dest_pts)

tensor([[9.]])

In [154]:
def phi(r, order):
    """Coordinate-wise nonlinearity used to define the order of the interpolation.
    See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
    Args:
    r: input op
    order: interpolation order
    Returns:
    phi_k evaluated coordinate-wise on r, for k = r
    """
    EPSILON=torch.tensor(1e-10)
    # using EPSILON prevents log(0), sqrt0), etc.
    # sqrt(0) is well-defined, but its gradient is not
    if order == 1:
        r = torch.max(r, EPSILON)
        r = torch.sqrt(r)
        return r
    elif order == 2:
        return 0.5 * r * torch.log(torch.max(r, EPSILON))
    elif order == 4:
        return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON))
    elif order % 2 == 0:
        r = torch.max(r, EPSILON)
        return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r)
    else:
        r = torch.max(r, EPSILON)
        return torch.pow(r, 0.5 * order)

In [155]:
phi(cross_squared_distance_matrix(src_pts, dest_pts), 2)

tensor([[9.8875]])

In [156]:
def apply_interpolation():
    pass

In [198]:
test_A = torch.tensor([[[1., 1., 6., 1.],
         [1., 1., 0., 0.],
         [6., 0., 1., 0.],
         [1., 0., 0., 1.]]])
test_B = torch.tensor([[[-1.,  3.],
         [ 0.,  0.],
         [ 0.,  0.],
         [ 0.,  0.]]])
print(test_A.shape, test_B.shape)
X, LU = torch.gesv(test_B, test_A)
X.shape, LU

torch.Size([1, 4, 4]) torch.Size([1, 4, 2])


(torch.Size([1, 4, 2]), tensor([[[ 6.0000,  0.0000,  1.0000,  0.0000],
          [ 0.1667,  1.0000, -0.1667,  0.0000],
          [ 0.1667,  1.0000,  6.0000,  1.0000],
          [ 0.1667,  0.0000, -0.0278,  1.0278]]]))