# DVF choosing

In [25]:
import torch
import torch.nn as nn
from torch.functional import F

# Assuming you have tensors A and points
A = torch.rand(2, 256, 256)
points = torch.rand(2, 20)*256  # Replace x1, y1, x2, y2, ..., xN, yN with your actual coordinates


# Reshape tensor points to have dimensions [2, N]
points = points.t().long()

# Use torch.gather to select values from A using indices from points
result = A[:, points[:, 0], points[:, 1]]

print(result)


tensor([[0.8639, 0.4092, 0.0490, 0.8620, 0.9790, 0.0482, 0.6686, 0.5067, 0.8228,
         0.6523, 0.1760, 0.2323, 0.3223, 0.5444, 0.8475, 0.4863, 0.2013, 0.9656,
         0.8200, 0.1296],
        [0.6927, 0.7752, 0.8152, 0.9625, 0.0389, 0.2407, 0.5616, 0.7734, 0.8034,
         0.8366, 0.2767, 0.0496, 0.4931, 0.2212, 0.9971, 0.8580, 0.0990, 0.9524,
         0.0837, 0.8154]])


In [26]:
result.shape

torch.Size([2, 20])

# Affine transformation

## Method 1

In [27]:
class AffineTransform(nn.Module):
    def __init__(self):
        super(AffineTransform, self).__init__()

    def forward(self, points, matrix):
        points = points.T
        # Add a row of ones to the input points for the affine transformation
        ones = torch.ones(1, points.size(1), dtype=points.dtype, device=points.device)
        points_homogeneous = torch.cat([points, ones], dim=0)

        # Apply the affine transformation
        # print dtype of matrix and points_homogeneous
        transformed_points = torch.mm(matrix, points_homogeneous.float())

        return transformed_points[:2, :]

In [28]:
affine_layer = AffineTransform()

In [29]:
affine_param = torch.tensor([[1.0, 0.0, 0.0],
                                [0.0, 1.0, 0.0]])

In [30]:
transformed_points = affine_layer(points, affine_param)
transformed_points.T

tensor([[141., 175.],
        [196.,  75.],
        [249.,   0.],
        [  9.,  36.],
        [142.,  90.],
        [ 64.,   0.],
        [136., 213.],
        [ 48., 122.],
        [ 76., 243.],
        [148., 242.],
        [ 53., 208.],
        [231.,  69.],
        [ 10., 200.],
        [  0., 148.],
        [ 76., 161.],
        [151., 189.],
        [129., 164.],
        [ 95., 168.],
        [201.,  10.],
        [167.,  97.]])

## Method 2

In [31]:
def transform_to_displacement_field(tensor, tensor_transform, device='cpu'):
    """
    Transforms a tensor using an affine transformation matrix and returns the corresponding displacement field.

    Args:
        tensor (torch.Tensor): The input tensor to transform, with shape (batch_size, channels, height, width).
        tensor_transform (torch.Tensor): The affine transformation matrix, with shape (batch_size, 2, 3).
        device (str, optional): The device to use for the computation (default: 'cpu').

    Returns:
        torch.Tensor: The displacement field tensor, with shape (2, height, width).
    """
    # function code here
    y_size, x_size = tensor.size(2), tensor.size(3)
    deformation_field = F.affine_grid(tensor_transform, tensor.size(), align_corners=False)
    gy, gx = torch.meshgrid(torch.arange(y_size), torch.arange(x_size))
    gy = gy.type(torch.FloatTensor).to(device)
    gx = gx.type(torch.FloatTensor).to(device)
    grid_x = (gx / (x_size - 1) - 0.5) * 2
    grid_y = (gy / (y_size - 1) - 0.5) * 2
    u_x = deformation_field[0, :, :, 0] - grid_x
    u_y = deformation_field[0, :, :, 1] - grid_y
    u_x = u_x / 2 * (x_size - 1)
    u_y = u_y / 2 * (y_size - 1)
    displacement_field = torch.cat((u_x.view(1, y_size, x_size), u_y.view(1, y_size, x_size)), dim=0)
    return displacement_field

In [32]:
def transform_points_DVF(points, M, image):
    # transform points using displacement field
    # DVF.shape = (2, H, W)
    # points.shape = (2, N)
    displacement_field = torch.zeros(image.shape[-1], image.shape[-1])
    DVF = transform_to_displacement_field(
        displacement_field.view(1, 1, displacement_field.size(0), displacement_field.size(1)), 
        M.view(1, 2, 3))

    # Reshape tensor points to have dimensions [2, N]
    points = points.t().long()

    # Use torch.gather to select values from A using indices from points
    result = DVF[:, points[:, 0], points[:, 1]]

    # Reshape result to have dimensions [2, N]
    result = result.t()
    # subtract the result from the original points
    points = points.float()
    result = torch.subtract(points, result)
    return result

In [33]:
transformed_points = transform_points_DVF(points.T, affine_param, A)
transformed_points

tensor([[ 1.4119e+02,  1.7505e+02],
        [ 1.9579e+02,  7.5268e+01],
        [ 2.4850e+02,  4.7461e-01],
        [ 8.6426e+00,  3.5537e+01],
        [ 1.4185e+02,  9.0057e+01],
        [ 6.3502e+01, -2.4805e-01],
        [ 1.3633e+02,  2.1303e+02],
        [ 4.7979e+01,  1.2169e+02],
        [ 7.6451e+01,  2.4280e+02],
        [ 1.4845e+02,  2.4208e+02],
        [ 5.3314e+01,  2.0771e+02],
        [ 2.3077e+02,  6.9404e+01],
        [ 1.0283e+01,  1.9954e+02],
        [ 8.0092e-02,  1.4750e+02],
        [ 7.6131e+01,  1.6080e+02],
        [ 1.5124e+02,  1.8909e+02],
        [ 1.2914e+02,  1.6401e+02],
        [ 9.5158e+01,  1.6787e+02],
        [ 2.0054e+02,  1.0287e+01],
        [ 1.6688e+02,  9.7154e+01]])