In [47]:
from cudants.utils.imageutils import scaling_and_squaring, image_gradient
import torch

In [31]:
I = torch.meshgrid(torch.arange(100), torch.arange(100), torch.arange(100), indexing='ij')
I = torch.stack(I, dim=-1).sum(dim=-1)[None, None].float() / 100

In [32]:
J = torch.meshgrid(100 - torch.arange(100), torch.arange(100), torch.arange(100), indexing='ij')
J = torch.stack(J, dim=-1).sum(dim=-1)[None, None].float() / 100

In [37]:
v = torch.randn(1, 100, 100, 100, 3) * 0.01

In [38]:
from torch.nn import functional as F
grid = F.affine_grid(torch.eye(3, 4)[None], size=I.shape)

  "Default grid_sample and affine_grid behavior has changed "


In [39]:
u = scaling_and_squaring(v, grid)

## Method 1 - Autodiff

In [40]:
warp = grid + u

In [41]:
warp.requires_grad_(True)
loss = F.mse_loss(F.grid_sample(I, warp, align_corners=True), J)
loss.backward()

In [42]:
gradval = warp.grad + 0

## Method 2 - Manual

In [64]:
gradimg = F.grid_sample(I, warp, align_corners=True) - J
gradjac = torch.zeros(1, 100, 100, 100, 3, 3)
for i in range(3):
    warpi = warp[:, None, ..., i].cuda()
    gradwarpi = image_gradient(warpi).cpu()
    gradjac[..., i] = gradwarpi.permute(0, 2, 3, 4, 1)

In [65]:
detjac = torch.linalg.det(gradjac).abs()

In [80]:
detjac.max()

tensor(0.0003, grad_fn=<MaxBackward1>)

In [67]:
gradI = image_gradient(F.grid_sample(I, warp, align_corners=True).cuda()).cpu()

In [69]:
gradimg.shape, gradI.shape, detjac.shape

(torch.Size([1, 1, 100, 100, 100]),
 torch.Size([1, 3, 100, 100, 100]),
 torch.Size([1, 100, 100, 100]))

In [70]:
grad = gradimg * gradI * detjac[:, None]

In [72]:
grad = grad.permute(0, 2, 3, 4, 1)

In [78]:
gradval.max()

tensor(8.8490e-05)