# TensorFlow's `sparse_image_warp` in Pytorch

In [226]:
# first call `%store spectro` in the other notebook to save spectro value here
%store -r spectro

In [227]:
import torch
import numpy as np

In [228]:
# source_control_point_locations: `[batch, num_control_points, 2]` float
#   `Tensor`
# dest_control_point_locations: `[batch, num_control_points, 2]` float
#   `Tensor`

In [229]:
num_dimensions = 2

In [230]:
# paper says 'a random point' so we'll use 1 num_control_point
src_pts, dest_pts = torch.tensor([[[1,3]]]), torch.tensor([[[1,6]]])

In [231]:
spectro.shape

torch.Size([1, 128, 1718])

In [328]:
def sparse_image_warp(img_tensor:torch.Tensor,
                      source_control_point_locations:torch.Tensor,
                      dest_control_point_locations:torch.Tensor,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundaries_points=0):
    control_point_flows = (dest_control_point_locations - source_control_point_locations)
    
#     clamp_boundaries = num_boundary_points > 0
#     boundary_points_per_edge = num_boundary_points - 1
    batch_size, image_height, image_width = img_tensor.shape
    grid_locations = get_grid_locations(image_height, image_width)
    flattened_grid_locations = torch.tensor(flatten_grid_locations(grid_locations, image_height, image_width))

#     flattened_grid_locations = constant_op.constant(
#         _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)

#     if clamp_boundaries:
#       (dest_control_point_locations,
#        control_point_flows) = _add_zero_flow_controls_at_boundary(
#            dest_control_point_locations, control_point_flows, image_height,
#            image_width, boundary_points_per_edge)

    flattened_flows = interpolate_spline(
        dest_control_point_locations,
        control_point_flows,
        flattened_grid_locations,
        interpolation_order,
        regularization_weight)

    dense_flows = create_dense_flows(flattened_flows, batch_size, image_height, image_width)

    warped_image = dense_image_warp(img_tensor, dense_flows)

    return warped_image, dense_flows

In [329]:
def get_grid_locations(image_height, image_width):
    """Wrapper for np.meshgrid."""

    y_range = np.linspace(0, image_height - 1, image_height)
    x_range = np.linspace(0, image_width - 1, image_width)
    y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij')
    return np.stack((y_grid, x_grid), -1)

In [330]:
def flatten_grid_locations(grid_locations, image_height, image_width):
    return np.reshape(grid_locations, [image_height * image_width, 2])

In [331]:
def create_dense_flows(flattened_flows, batch_size, image_height, image_width):
    # possibly .view
    return torch.reshape(flattened_flows, [batch_size, image_height, image_width, 2])

## Interpolate Spline

In [332]:
def interpolate_spline(train_points:torch.Tensor, train_values:torch.Tensor, query_points:torch.Tensor, order, regularization_weight=0.0,):
    # First, fit the spline to the observed data.
    w, v = solve_interpolation(train_points, train_values, order, regularization_weight)
    # Then, evaluate the spline at the query locations.
    query_values = apply_interpolation(query_points, train_points, w, v, order)

    return query_values

In [333]:
def solve_interpolation(train_points, train_values, order, regularization_weight):
    b, n, d = train_points.shape
    k = train_values.shape[-1]

    # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
    # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
    # To account for python style guidelines we use
    # matrix_a for A and matrix_b for B.
    
    c = train_points
    f = train_values.float()
    
    matrix_a = phi(cross_squared_distance_matrix(c,c), order).unsqueeze(0)  # [b, n, n]
#     print('Matrix A', matrix_a, matrix_a.shape)
#     if regularization_weight > 0:
#         batch_identity_matrix = array_ops.expand_dims(
#           linalg_ops.eye(n, dtype=c.dtype), 0)
#         matrix_a += regularization_weight * batch_identity_matrix

    # Append ones to the feature values for the bias term in the linear model.
    ones = torch.ones(1, dtype=src_pts.dtype).view([-1, 1, 1])
    matrix_b = torch.cat((c, ones), 2).float()  # [b, n, d + 1]
#     print('Matrix B', matrix_b, matrix_b.shape)

    # [b, n + d + 1, n]
    left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)
#     print('Left Block', left_block, left_block.shape)

    num_b_cols = matrix_b.shape[2]  # d + 1
#     print('Num_B_Cols', matrix_b.shape)
    #     lhs_zeros = torch.zeros((b, num_b_cols, num_b_cols), dtype=train_points.dtype).float()

    # In Tensorflow, zeros are used here. Pytorch gesv fails with zeros for some reason we don't understand.
    # So instead we use very tiny randn values (variance of one, zero mean) on one side of our multiplication.
    lhs_zeros = torch.randn((b, num_b_cols, num_b_cols)) / 1e10
    right_block = torch.cat((matrix_b, lhs_zeros),
                                   1)  # [b, n + d + 1, d + 1]
#     print('Right Block', right_block, right_block.shape)
    lhs = torch.cat((left_block, right_block),
                           2)  # [b, n + d + 1, n + d + 1]
#     print('LHS', lhs, lhs.shape)

    rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype).float()
    rhs = torch.cat((f, rhs_zeros), 1)  # [b, n + d + 1, k]
#     print('RHS', rhs, rhs.shape)

    # Then, solve the linear system and unpack the results.
    X, LU = torch.gesv(rhs, lhs)
    w = X[:, :n, :]
    v = X[:, n:, :]

    return w, v

In [334]:
def cross_squared_distance_matrix(x, y):
    """Pairwise squared distance between two (batch) matrices' rows (2nd dim).
        Computes the pairwise distances between rows of x and rows of y
        Args:
        x: [batch_size, n, d] float `Tensor`
        y: [batch_size, m, d] float `Tensor`
        Returns:
        squared_dists: [batch_size, n, m] float `Tensor`, where
        squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
    """
    x_norm_squared = torch.sum(torch.mul(x, x))
    y_norm_squared = torch.sum(torch.mul(y, y))

    x_y_transpose = torch.matmul(x.squeeze(0), y.squeeze(0).transpose(0,1))
    
    # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj
    squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared

    return squared_dists.float()

In [335]:
# for [1,3] a [1,6]
cross_squared_distance_matrix(src_pts, dest_pts)

tensor([[9.]])

In [336]:
def phi(r, order):
    """Coordinate-wise nonlinearity used to define the order of the interpolation.
    See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
    Args:
    r: input op
    order: interpolation order
    Returns:
    phi_k evaluated coordinate-wise on r, for k = r
    """
    EPSILON=torch.tensor(1e-10)
    # using EPSILON prevents log(0), sqrt0), etc.
    # sqrt(0) is well-defined, but its gradient is not
    if order == 1:
        r = torch.max(r, EPSILON)
        r = torch.sqrt(r)
        return r
    elif order == 2:
        return 0.5 * r * torch.log(torch.max(r, EPSILON))
    elif order == 4:
        return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON))
    elif order % 2 == 0:
        r = torch.max(r, EPSILON)
        return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r)
    else:
        r = torch.max(r, EPSILON)
        return torch.pow(r, 0.5 * order)

In [342]:
def apply_interpolation(query_points, train_points, w, v, order):
    """Apply polyharmonic interpolation model to data.
    Given coefficients w and v for the interpolation model, we evaluate
    interpolated function values at query_points.
    Args:
    query_points: `[b, m, d]` x values to evaluate the interpolation at
    train_points: `[b, n, d]` x values that act as the interpolation centers
                    ( the c variables in the wikipedia article)
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d, k]` weights on each input dimension
    order: order of the interpolation
    Returns:
    Polyharmonic interpolation evaluated at points defined in query_points.
    """
    query_points = query_points.unsqueeze(0)
    # First, compute the contribution from the rbf term.
#     print(query_points.shape, train_points.shape)
    pairwise_dists = cross_squared_distance_matrix(query_points.float(), train_points.float())
#     print('Pairwise', pairwise_dists)
    phi_pairwise_dists = phi(pairwise_dists, order)
#     print('Pairwise phi', phi_pairwise_dists)

    rbf_term = torch.matmul(phi_pairwise_dists, w)

    # Then, compute the contribution from the linear term.
    # Pad query_points with ones, for the bias term in the linear model.
    ones = torch.ones_like(query_points[..., :1])
    query_points_pad = torch.cat((
      query_points,
      ones
    ), 2).float()
    linear_term = torch.matmul(query_points_pad, v)

    return rbf_term + linear_term


In [343]:
test_A = torch.tensor([[[1., 1., 6., 1.],
         [1., 1., 0., 0.],
         [6., 0., 1., 0.],
         [1., 0., 0., 1.]]])
test_B = torch.tensor([[[-1.,  3.],
         [ 0.,  0.],
         [ 0.,  0.],
         [ 0.,  0.]]])
print(test_A.shape, test_B.shape)
X, LU = torch.gesv(test_B, test_A)
X.shape, LU

torch.Size([1, 4, 4]) torch.Size([1, 4, 2])


(torch.Size([1, 4, 2]), tensor([[[ 6.0000,  0.0000,  1.0000,  0.0000],
          [ 0.1667,  1.0000, -0.1667,  0.0000],
          [ 0.1667,  1.0000,  6.0000,  1.0000],
          [ 0.1667,  0.0000, -0.0278,  1.0278]]]))

In [400]:
sparse_image_warp(spectro, src_pts, dest_pts)

stacked torch.Size([1718, 128, 2])
batched_grid torch.Size([1, 128, 1718, 2])
torch.Size([1, 128, 1718, 2])


NameError: name '_interpolate_bilinear' is not defined

In [399]:
def dense_image_warp(image, flow):
    """Image warping using per-pixel flow vectors.
    Apply a non-linear warp to the image, where the warp is specified by a dense
    flow field of offset vectors that define the correspondences of pixel values
    in the output image back to locations in the  source image. Specifically, the
    pixel value at output[b, j, i, c] is
    images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].
    The locations specified by this formula do not necessarily map to an int
    index. Therefore, the pixel value is obtained by bilinear
    interpolation of the 4 nearest pixels around
    (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside
    of the image, we use the nearest pixel values at the image boundary.
    Args:
    image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
    flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
    name: A name for the operation (optional).
    Note that image and flow can be of type tf.half, tf.float32, or tf.float64,
    and do not necessarily have to be the same type.
    Returns:
    A 4-D float `Tensor` with shape`[batch, height, width, channels]`
    and same type as input image.
    Raises:
    ValueError: if height < 2 or width < 2 or the inputs have the wrong number
    of dimensions.
    """
    image = image.unsqueeze(3) # add a single channel dimension to image tensor
    batch_size, height, width, channels = image.shape

    # The flow is defined on the image grid. Turn the flow into a list of query
    # points in the grid space.
    grid_x, grid_y = torch.meshgrid(
        torch.arange(width), torch.arange(height))
    
    stacked_grid = torch.stack((grid_y, grid_x), dim=2).float()
    print('stacked', stacked_grid.shape)
    
    batched_grid = stacked_grid.unsqueeze(-1).permute(3, 1, 0, 2)
    print('batched_grid', batched_grid.shape)
    
    print(flow.shape)
    query_points_on_grid = batched_grid - flow
    query_points_flattened = torch.reshape(query_points_on_grid,
                                               [batch_size, height * width, 2])
    # Compute values at the query points, then reshape the result back to the
    # image grid.
    ## CONTINUE HERE
    interpolated = interpolate_bilinear(image, query_points_flattened)
    interpolated = torch.reshape(interpolated,
                                     [batch_size, height, width, channels])
    return interpolated