In [40]:
import torch
import torch.nn as nn
# 根据热图（Heatmap）采样关键点（keypoints）TODO
class KeypointSampler(nn.Module):
  '''
  Sample keypoints according to a Heatmap
  Input
    x: [B, 1, H, W] heatmap

  Returns
    [list]:
      kps: [N, 2] - keypoint positions
      log_probs: [N] - logprobs for each kp
  '''  
  def __init__(self, window_size = 8): 
    super().__init__()
    self.window_size = window_size

  # 将输入的热图张量进行划分，划分成一个个的窗口。
  def gridify(self, x):
    B, C, H, W = x.shape
    x = x.unfold(2, self.window_size, self.window_size)                              \
          .unfold(3, self.window_size, self.window_size)                             \
          .reshape(B, C, H//self.window_size, W//self.window_size, self.window_size**2)

    return x

  def sample(self, grid):
    '''
    Sample keypoints given a grid where each cell has logits stacked in last dimension
    Input
      grid: [B, C, H//w, W//w, w*w]

    Returns
      log_probs: [B, C, H//w, W//w ] - logprobs of selected samples
      choices: [B, C, H//w, W//w] indices of choices
      accept_mask: [B, C, H//w, W//w] mask of accepted keypoints

    '''
    chooser = torch.distributions.Categorical(logits = grid)
    choices = chooser.sample()
    selected_choices = torch.gather(grid, -1, choices.unsqueeze(-1)).squeeze(-1)

    flipper = torch.distributions.Bernoulli(logits = selected_choices)
    accepted_choices = flipper.sample()

    #Sum log-probabilities is equivalent to multiplying the probabilities
    log_probs = chooser.log_prob(choices) + flipper.log_prob(accepted_choices)
    print(accepted_choices)
    accept_mask = accepted_choices.gt(0)
    print(accept_mask)
    return log_probs.squeeze(1), choices, accept_mask.squeeze(1)


  def forward(self, x):
    B, C, H, W = x.shape
    keypoint_cells = self.gridify(x)
    idx_cells = self.gridify( torch.dstack(torch.meshgrid(torch.arange(x.shape[-2], dtype=torch.float32),
                                                          torch.arange(x.shape[-1], dtype=torch.float32),
                                                          #indexing='ij'))     \
                                                                        ))     \
                                                         .permute(2,0,1).unsqueeze(0) 
                                                         .expand(B,-1,-1,-1) ).to(x.device)
                                                         

    log_probs, idx, mask = self.sample(keypoint_cells)
    

    keypoints = torch.gather(idx_cells, -1, idx.repeat(1,2,1,1).unsqueeze(-1)).squeeze(-1).permute(0,2,3,1)
    
    xy_probs = [  {'xy':keypoints[b][mask[b]].flip(-1), 'logprobs':log_probs[b][mask[b]]}
                  for b in range(B) ]

    return xy_probs

key1 = KeypointSampler(2)
tensor = torch.rand(1, 1, 4, 6)
xy_p = key1(tensor)

tensor([[[[0., 1., 1.],
          [1., 1., 1.]]]])
tensor([[[[False,  True,  True],
          [ True,  True,  True]]]])


In [36]:
import torch
import torch.nn.functional as F 
x = torch.rand(6,3)
y = torch.rand(6,3)
T = 1.


Dmat = 2. - torch.cdist(x.unsqueeze(0), y.unsqueeze(0)).squeeze(0)
print(Dmat.shape)
logprob_rows = F.log_softmax(Dmat * T, dim=1)
print(logprob_rows)
logprob_cols = F.log_softmax(Dmat.t() * T, dim=1)
choice_rows = torch.argmax(logprob_rows, dim=1)
print(choice_rows)
choice_cols = torch.argmax(logprob_cols, dim=1)

seq = torch.arange(choice_cols.shape[0], dtype = choice_cols.dtype, device = choice_cols.device)
mutual = choice_rows[choice_cols] == seq

logprob_rows = torch.gather(logprob_rows, -1, choice_rows.unsqueeze(-1)).squeeze(-1)
print(logprob_rows)
logprob_cols = torch.gather(logprob_cols, -1, choice_cols.unsqueeze(-1)).squeeze(-1)

log_probs = logprob_rows[choice_cols[mutual]] + logprob_cols[seq[mutual]]

dmatches = torch.cat((choice_cols[mutual].unsqueeze(-1), seq[mutual].unsqueeze(-1)), dim=1)


torch.Size([6, 6])
tensor([[-1.8381, -1.4853, -1.9711, -2.2752, -1.7852, -1.5867],
        [-1.7578, -1.7432, -1.8685, -1.9225, -1.6718, -1.8072],
        [-1.6294, -2.3135, -1.6149, -1.6106, -1.7604, -2.0069],
        [-1.6025, -2.2470, -1.4916, -1.8455, -1.7750, -1.9629],
        [-1.7107, -1.9723, -1.7986, -1.7264, -1.6720, -1.9053],
        [-1.7509, -2.1092, -2.1430, -1.5009, -1.5810, -1.8385]])
tensor([1, 4, 3, 2, 4, 3])
tensor([-1.4853, -1.6718, -1.6106, -1.4916, -1.6720, -1.5009])


In [37]:

print(Dmat)
print(logprob_rows)
print(dmatches )
print(log_probs)

tensor([[1.4079, 1.7607, 1.2749, 0.9709, 1.4608, 1.6594],
        [1.3955, 1.4101, 1.2849, 1.2308, 1.4815, 1.3461],
        [1.3891, 0.7050, 1.4036, 1.4079, 1.2581, 1.0116],
        [1.5784, 0.9338, 1.6892, 1.3354, 1.4058, 1.2179],
        [1.3964, 1.1349, 1.3086, 1.3808, 1.4352, 1.2019],
        [1.4382, 1.0798, 1.0461, 1.6881, 1.6080, 1.3506]])
tensor([-1.4853, -1.6718, -1.6106, -1.4916, -1.6720, -1.5009])
tensor([[0, 1],
        [3, 2],
        [5, 3]])
tensor([-2.7467, -2.9476, -2.9630])


In [43]:
in_channels=2
nparam = 4
attn = nn.Sequential(
                    nn.Linear(in_channels*2, in_channels*4),
                    nn.BatchNorm1d(in_channels*4, affine = False),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(in_channels*4, in_channels*4),
                    #nn.BatchNorm1d(in_channels*4, affine = False),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(in_channels*4, nparam*2),
                    nn.Tanh(),
                    )
print(attn[0])
#zero-out layer params for initial identity TPS transform
for i in [-2, -5, -9]:
    attn[i].weight.data.normal_(0., 1e-5) 
    attn[i].bias.data.zero_()#normal_(0., 1e-5)

Linear(in_features=4, out_features=8, bias=True)


In [1]:
import cv2
import numpy as np
import matplotlib as plt
class TPS:       
    @staticmethod
    def fit(c, lambd=0., reduced=False):       
        n = c.shape[0]

        U = TPS.u(TPS.d(c, c))
        K = U + np.eye(n, dtype=np.float32)*lambd

        P = np.ones((n, 3), dtype=np.float32)
        P[:, 1:] = c[:, :2]
        v = np.zeros(n+3, dtype=np.float32)
        v[:n] = c[:, -1]

        A = np.zeros((n+3, n+3), dtype=np.float32)
        A[:n, :n] = K
        A[:n, -3:] = P
        A[-3:, :n] = P.T
        theta = np.linalg.solve(A, v) # p has structure w,a
        return theta[1:] if reduced else thete      

    @staticmethod
    def z(x, c, theta):
        x = np.atleast_2d(x)
        U = TPS.u(TPS.d(x, c))
        w, a = theta[:-3], theta[-3:]
        reduced = theta.shape[0] == c.shape[0] + 2
        if reduced:
            w = np.concatenate((-np.sum(w, keepdims=True), w))
        b = np.dot(U, w)
        return a[0] + a[1]*x[:, 0] + a[2]*x[:, 1] + b



def show_warped(img, warped, c_src, c_dst):
    fig, axs = plt.subplots(1, 2, figsize=(16,8))
    axs[0].axis('off')
    axs[1].axis('off')
    axs[0].imshow(img[...,::-1], origin='upper')
    axs[0].scatter(c_src[:, 0]*img.shape[1], c_src[:, 1]*img.shape[0], marker='^', color='black')
    axs[1].imshow(warped[...,::-1], origin='upper')
    axs[1].scatter(c_dst[:, 0]*warped.shape[1], c_dst[:, 1]*warped.shape[0], marker='^', color='black')
    plt.show()

def warp_image_cv(img, c_src, c_dst, dshape=None):
    dshape = dshape or img.shape
    theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)
    grid = tps.tps_grid(theta, c_dst, dshape)
    mapx, mapy = tps.tps_grid_to_remap(grid, img.shape)
    return cv2.remap(img, mapx, mapy, cv2.INTER_CUBIC)

img = cv2.imread('test.jpg')

c_src = np.array([
    [0.44, 0.18],
    [0.55, 0.18],
    [0.33, 0.23],
    [0.66, 0.23],
    [0.32, 0.79],
    [0.67, 0.80],
])

c_dst = np.array([
    [0.693, 0.466],
    [0.808, 0.466],
    [0.572, 0.524],
    [0.923, 0.524],
    [0.545, 0.965],
    [0.954, 0.966],
])


warped_front = warp_image_cv(img, c_src, c_dst, dshape=(512, 512))
show_warped(img, warped1, c_src_front, c_dst_front)


[ WARN:0@0.071] global loadsave.cpp:248 findDecoder imread_('test.jpg'): can't open/read file: check file path/integrity


NameError: name 'tps' is not defined

In [9]:
import torch
import torch.nn.functional as F
batchSize = 10
gridSize = (32,32)

ident = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], device = 'cpu').expand(batchSize, -1, -1)
grid = F.affine_grid(ident, (batchSize, 1) + gridSize, align_corners= False)
grid_y = grid[..., 0].view(batchSize , -1)
grid_x = grid[..., 1].view(batchSize , -1)

In [11]:
print(ident.shape)
print(grid.shape)
print(grid_x.shape)

torch.Size([10, 2, 3])
torch.Size([10, 32, 32, 2])
torch.Size([10, 1024])


In [8]:
(3,2)+(32,32)

(3, 2, 32, 32)