In [1]:
import torch
import numpy as np
import torch.nn.functional as F

In [4]:
# Descriptor maps are laid out as N x D x H/8 x W/8:
# one D-dimensional descriptor per 8x8-pixel cell of an H x W image.
H = W = 32
Hc = H // 8
Wc = W // 8

# Random stand-in for the network's raw descriptor output (N=32, D=256).
desc_raw = torch.Tensor(np.random.random((32, 256, Hc, Wc)))
# The "warped" descriptors are the very same tensor object (no copy is made).
warped_desc_raw = desc_raw

# Constants read by descriptor_loss: mp / mn are the clamp margins of the
# s == 1 and s == 0 terms respectively, ld is the weight on the s == 1 term.
mp, mn, ld = 1, 0.2, 250

In [5]:
def descriptor_head(desc_raw):
    """Upsample a raw descriptor map by 8x and L2-normalise it per location.

    Args:
        desc_raw: tensor accepted by ``F.interpolate`` in bilinear mode; the
            notebook feeds it an N x D x H/8 x W/8 map.

    Returns:
        ``(desc_raw, desc)``: the untouched input and the bilinearly
        upsampled map whose vectors along dim 1 have unit L2 norm.
    """
    upsampled = F.interpolate(
        desc_raw,
        scale_factor=8,
        mode='bilinear',
        align_corners=False,
    )
    # Normalise along the channel axis so each location holds a unit vector.
    unit_desc = F.normalize(upsampled, p=2, dim=1)
    return desc_raw, unit_desc

# should use the raw descriptor output from the network
def descriptor_loss(desc, warped_desc, positive_margin=None, negative_margin=None, lambda_d=None):
    """Hinge loss between every pair of cell descriptors of two descriptor maps.

    For each cell (h, w) of ``desc`` and each cell (i, j) of ``warped_desc``
    the dot product of their D-dimensional descriptors is penalised with

        lambda_d * s * max(0, positive_margin - dot)
            + (1 - s) * max(0, dot - negative_margin)

    where ``s`` is 1 when the two cell centres lie within 8 pixels of each
    other and 0 otherwise.

    Args:
        desc: raw descriptor map of shape (N, D, Hc, Wc).
        warped_desc: descriptor map of the warped image, same shape as ``desc``.
        positive_margin: margin of the ``s == 1`` term; falls back to the
            notebook-level ``mp`` when None.
        negative_margin: margin of the ``s == 0`` term; falls back to the
            notebook-level ``mn`` when None.
        lambda_d: weight of the ``s == 1`` term; falls back to the
            notebook-level ``ld`` when None.

    Returns:
        Scalar tensor: the sum of all pairwise terms (over the batch too)
        divided by ``(Hc * Wc) ** 2``.

    Raises:
        ValueError: if the inputs are not 4-D or their shapes differ.
    """
    if desc.dim() != 4 or desc.shape != warped_desc.shape:
        raise ValueError(
            "desc and warped_desc must both have shape (N, D, Hc, Wc); got "
            "{} and {}".format(tuple(desc.shape), tuple(warped_desc.shape))
        )

    # Resolve the hyper-parameters lazily so callers that rely on the
    # notebook-level constants keep working unchanged.
    if positive_margin is None:
        positive_margin = mp
    if negative_margin is None:
        negative_margin = mn
    if lambda_d is None:
        lambda_d = ld

    # Grid size comes from the argument itself, not from a notebook global.
    Hc, Wc = desc.shape[-2:]
    cell = 8

    # (row, col) index of every cell, built on the input's device so the loss
    # also works for CUDA tensors.
    rows = torch.arange(Hc, device=desc.device).view(Hc, 1).expand(Hc, Wc)
    cols = torch.arange(Wc, device=desc.device).view(1, Wc).expand(Hc, Wc)
    p_hw = torch.stack((rows, cols), dim=-1)
    # Pixel coordinates of each cell centre.
    p_hw = (p_hw * cell + cell // 2).float()
    warped_p_hw = p_hw  # should be the warped location for p_hw

    # s[0, h, w, i, j] == 1 where cell (h, w) and warped cell (i, j) correspond.
    dist = torch.norm(
        p_hw.reshape(1, Hc, Wc, 1, 1, 2) - warped_p_hw.reshape(1, 1, 1, Hc, Wc, 2),
        p=2,
        dim=-1,
    )
    s = torch.le(dist, 8).float()

    # Contract over the channel axis D. A plain .view() of an (N, D, Hc, Wc)
    # tensor to (N, Hc, Wc, ..., D) would mix channels with spatial positions,
    # so the contraction is spelled out explicitly instead.
    dot_prod = torch.einsum('nchw,ncij->nhwij', desc, warped_desc)

    loss = (
        lambda_d * s * torch.clamp(positive_margin - dot_prod, min=0.)
        + (1 - s) * torch.clamp(dot_prod - negative_margin, min=0.)
    )
    # normalization needs to be revisited, after applying homography
    return torch.sum(loss) / (Hc * Wc) ** 2

In [6]:
# Evaluate the loss on the random descriptors; warped_desc_raw is the same
# tensor as desc_raw, so each descriptor map is compared against itself.
descriptor_loss(desc_raw, warped_desc_raw)

tensor(1541.3599)