In [89]:
import math
import os

# Expose only physical GPU 4 to this process; set before `torch` is imported in the next cell
os.environ.update(CUDA_VISIBLE_DEVICES='4')

In [2]:
import torch
import torch.nn as nn

In [123]:
class SoftProposal(nn.Module):
    """Soft Proposal layer: re-weights a feature map by a spatial proposal map.

    The proposal map M is obtained by power iteration M <- D M, where D is a
    column-normalized (N*N) x (N*N) matrix combining per-location feature
    distances with a Gaussian kernel on grid distances.

    Args:
        N: side length of the square feature map.
        max_iter: maximum number of power iterations.
        err_th: stop early once the mean absolute change of M drops below this.
        device: device on which the spatial kernel is precomputed.
        sigma_ratio: std of the spatial Gaussian as a fraction of N, i.e. the
            kernel variance is ``(sigma_ratio * N) ** 2``. The default 0.15
            reproduces the notebook's ``ep2 = (0.15 * N) ** 2``.
    """

    def __init__(self, N, max_iter, err_th, device=torch.device('cuda'), sigma_ratio=0.15):
        super(SoftProposal, self).__init__()
        self.N = N
        self.device = device
        # Kernel variance; previously read from a notebook global `ep2` that is
        # defined in a later cell, so a fresh top-to-bottom run raised NameError.
        self.ep2 = (sigma_ratio * N) ** 2
        self.dist = self.init_dist().to(device)
        self.max_iter = max_iter
        self.err_th = err_th

    def init_dist(self):
        """Return the (N*N, N*N) kernel exp(-||g_a - g_b||^2 / (2 * ep2)) over grid coordinates."""
        coords = torch.arange(end=self.N)
        xx, yy = torch.meshgrid(coords, coords)
        grid = torch.stack((xx, yy), dim=0)                      # 2 x N x N
        offsets = grid.reshape(2, -1, 1) - grid.reshape(2, 1, -1)  # 2 x NN x NN
        sq_dist = torch.sum(offsets ** 2, dim=0).float()
        return torch.exp(-sq_dist / (2 * self.ep2))

    def forward(self, feature):
        '''
        feature: B * K * N * N
        returns: B * K * N * N, the feature map scaled by the proposal map M
        '''
        B, K, N, _ = feature.shape
        x = feature.reshape(B, K, N * N, 1)
        y = feature.reshape(B, K, 1, N * N)
        # Follow the input's device rather than the construction-time device.
        dist = self.dist.to(feature.device)
        D_ = torch.norm(x - y, dim=1) * dist  # B x NN x NN, symmetric
        # Normalize each COLUMN to sum to 1. M <- D M then conserves sum(M) = 1
        # and converges towards the chain's stationary distribution. Dividing by
        # row sums (the previous code) makes D row-stochastic, which maps the
        # uniform start vector to itself, so M never changed. The clamp avoids
        # 0/0 = NaN when a location's feature equals every other location's.
        D = D_ / torch.sum(D_, dim=1, keepdim=True).clamp_min(1e-12)
        M = torch.ones(B, N * N, 1, dtype=D.dtype, device=feature.device) / (N * N)
        for _ in range(self.max_iter):
            new_M = torch.bmm(D, M)
            converged = torch.mean(torch.abs(new_M - M)) < self.err_th
            M = new_M
            if converged:
                break
        return M.view(B, 1, N, N) * feature

In [124]:
s = SoftProposal(N=7, max_iter=10, err_th=1e-5)  # 7x7 maps, <=10 power iterations, early-stop at 1e-5

In [125]:
feature = torch.randn((4, 2202, 7, 7)).cuda()  # B=4, K=2202, N=7; sampled on CPU, copied to the GPU

In [127]:
%time f = s(feature)

CPU times: user 3.65 ms, sys: 0 ns, total: 3.65 ms
Wall time: 2.71 ms


In [100]:
# NOTE(review): no cell in this notebook defines a top-level `D`; this relies on hidden kernel state.
D.size()

torch.Size([4, 49, 49])

In [114]:
a, b = torch.randn(2, 2), torch.randn(2, 2)
last = a        # second name for the same tensor object
print(a)
a = b @ a       # rebinds the name `a`; the tensor `last` refers to is not modified
print(last)
print(a)

tensor([[ 0.1989,  0.3080],
        [-1.0303,  0.1004]])
tensor([[ 0.1989,  0.3080],
        [-1.0303,  0.1004]])
tensor([[-1.0818, -0.7737],
        [ 0.6590,  0.0599]])


In [111]:
# Brute-force check of D against its definition:
#   D[b, i*N + j, p*N + q] == ||feature[b, :, i, j] - feature[b, :, p, q]|| * exp(-((i-p)^2 + (j-q)^2) / (2 * ep2))
# NOTE(review): `D` and `ep2` come from other cells / kernel state; confirm they are current.
n_side = feature.shape[-1]
for i in range(n_side):
    for j in range(n_side):
        for p in range(n_side):
            for q in range(n_side):
                delta_u = torch.norm(feature[:, :, i, j] - feature[:, :, p, q], dim=1)
                delta_L = math.exp(-((i - p) ** 2 + (j - q) ** 2) / (2 * ep2))
                D_t = delta_u * delta_L
                D_f = D[:, i * n_side + j, p * n_side + q]
                # Element-wise tolerance: a signed sum lets negative and cancelling errors pass.
                assert torch.allclose(D_t, D_f, rtol=1e-5, atol=1e-4), (
                    f"D mismatch at ({i}, {j}) vs ({p}, {q}): "
                    f"max abs err {torch.max(torch.abs(D_t - D_f)).item()}"
                )

In [110]:
(D_t - D_f).sum().item()  # signed total difference for the last pair visited by the check loop

1.52587890625e-05

In [108]:
D_f  # slice D[:, i*7 + j, p*7 + q] from the check loop's last iteration (one value per batch element)

tensor([42.4220, 42.2750, 42.3989, 42.9046], device='cuda:0')

In [46]:
# Shapes for the scratch cells: batch B, channels K, square map side N
B, K, N = 4, 2202, 7
# Variance of the spatial Gaussian exp(-d^2 / (2 * ep2)); its std is 0.15 * N grid cells
ep2 = (0.15 * N) ** 2

In [47]:
tmp_feature = torch.randn((B, K, N, N))  # random CPU feature map for shape experiments

In [48]:
x = tmp_feature.reshape(B, K, N * N).unsqueeze(-1)  # B x K x NN x 1: locations along dim 2

In [49]:
# B x K x 1 x NN: locations along dim 3, so `x - y` broadcasts to B x K x NN x NN
y = tmp_feature.reshape(B, K, N * N).unsqueeze(2)

In [69]:
D_ = (x - y).norm(dim=1)  # B x NN x NN pairwise L2 distances over the K channels

In [72]:
D_.sum(dim=2).shape

torch.Size([4, 49])

In [31]:
# NOTE(review): rebinding `x`/`y` here discards the feature tensors assigned to them in earlier cells.
x, y = torch.arange(N), torch.arange(N)
xx, yy = torch.meshgrid(x, y)

In [32]:
grid = torch.stack([xx, yy])  # 2 x N x N coordinate planes
# pairwise coordinate offsets, 2 x NN x NN
dist = grid.flatten(start_dim=1)[:, :, None] - grid.flatten(start_dim=1)[:, None, :]

In [33]:
# Gaussian kernel over squared grid distances, NN x NN. Recomputed from `grid`
# (not from the previous `dist`) so re-running this cell does not apply exp twice.
dist = torch.exp(-torch.sum((grid.reshape(2, -1, 1) - grid.reshape(2, 1, -1)) ** 2, dim=0).float() / (2 * ep2))

In [34]:
dist.size()

torch.Size([25, 25])

In [36]:
# NOTE(review): `D` is hidden kernel state; the recorded 25 x 25 output does not match N = 7 (49 locations), so it is stale.
(D * dist).size()

torch.Size([4, 25, 25])

In [None]:
class SoftProposal(nn.Module):
    """Work-in-progress variant that builds the Soft Proposal transition matrix.

    NOTE(review): this redefinition shadows the earlier ``SoftProposal`` class
    once this cell runs. Its ``forward`` returns the transition matrix ``D``,
    not a re-weighted feature map.

    Args:
        opt: dict-like with keys ``'K'`` (channels), ``'N'`` (side of the square
            feature map) and optional ``'sigma_ratio'`` (default 0.15). The
            spatial Gaussian's variance is ``(sigma_ratio * N) ** 2``.
    """

    def __init__(self, opt):
        super(SoftProposal, self).__init__()
        self.K = opt['K']
        self.N = opt['N']
        # Kernel variance; previously read from a notebook global `ep2`.
        self.ep2 = (opt.get('sigma_ratio', 0.15) * self.N) ** 2
        self.dist = self.init_dist()

    def init_dist(self):
        """Return the (N*N, N*N) kernel exp(-||g_a - g_b||^2 / (2 * ep2)) over grid coordinates."""
        coords = torch.arange(end=self.N)
        xx, yy = torch.meshgrid(coords, coords)
        grid = torch.stack((xx, yy), dim=0)                      # 2 x N x N
        offsets = grid.reshape(2, -1, 1) - grid.reshape(2, 1, -1)  # 2 x NN x NN
        sq_dist = torch.sum(offsets ** 2, dim=0).float()
        return torch.exp(-sq_dist / (2 * self.ep2))

    def forward(self, feature):
        '''
        feature: B * K * N * N
        returns: B * (N*N) * (N*N) column-normalized transition matrix D
        '''
        B, K, N, _ = feature.shape
        x = feature.reshape(B, K, N * N, 1)
        y = feature.reshape(B, K, 1, N * N)
        # `self.dist` is built on the CPU; move it to wherever the input lives.
        D_ = torch.norm(x - y, dim=1) * self.dist.to(feature.device)
        # Previously `D = D / torch.sum(D_, dim=-1)` read D before assignment, and
        # a (B, NN) divisor does not broadcast against (B, NN, NN) unless B == NN.
        # keepdim=True gives (B, 1, NN): every column sums to 1. The clamp avoids
        # 0/0 = NaN for an all-zero column.
        D = D_ / torch.sum(D_, dim=1, keepdim=True).clamp_min(1e-12)
        return D