In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
import numpy as np
from model import PositionalEncoding
from model import NeighborModel

In [2]:
def get_bou_features(
    img_features: torch.Tensor, boundary: torch.Tensor
) -> torch.Tensor:
    """Sample one feature vector per boundary point from a batch of feature maps.

    Args:
        img_features: Feature maps indexed as ``[batch, channel, dim2, dim3]``.
        boundary: Integer index tensor indexed as ``[batch, point, 2]``.
            ``boundary[b, p, 0]`` indexes dim 2 of ``img_features`` and
            ``boundary[b, p, 1]`` indexes dim 3.

    Returns:
        Contiguous tensor of shape ``(batch, channel, point)`` where
        ``out[b, c, p] == img_features[b, c, boundary[b, p, 0], boundary[b, p, 1]]``.
        The batch size is taken from ``boundary``.
    """
    # One gather for the whole batch instead of growing the result with
    # torch.cat inside a Python loop (which re-copies everything each step).
    batch_idx = torch.arange(
        boundary.shape[0], device=boundary.device
    ).unsqueeze(1)
    # Move channels last so the three index tensors address adjacent leading
    # dimensions; the result is then unambiguously (batch, point, channel).
    channels_last = img_features.permute(0, 2, 3, 1)
    sampled = channels_last[batch_idx, boundary[..., 0], boundary[..., 1]]
    # Restore the (batch, channel, point) layout the callers expect.
    return sampled.permute(0, 2, 1).contiguous()

In [3]:
# --- Configuration -----------------------------------------------------------
# Number of pyramid levels for which 7x7 neighbourhood dot products are built.
scale_level = 4
# Token width: 1024 sampled feature channels + 49 dot products per scale level
# + 2 boundary coordinates (this matches the concatenation that builds `tokens`).
d_token = 1024 + 7 * 7 * scale_level + 2
# Number of boundary points per sample; one linear head is created per point.
boundary_num = 80
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Modules -----------------------------------------------------------------
res50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)
# Backbone = ResNet-50 without its last three child modules
# (presumably layer4, avgpool and fc -- the kept part outputs 1024 channels).
res50_bone = nn.Sequential(*list(res50.children())[:-3]).to(device)
positional_encoding = PositionalEncoding(d_token).to(device)
# NOTE: the identifier is spelled `tranformer_encoder` (sic); later cells use
# this exact name, so it is kept as-is.
tranformer_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=d_token,
        nhead=1,
        batch_first=True,
    ),
    num_layers=1,
).to(device)
# One head per boundary point; each predicts 2 coordinate values followed by
# 1024 feature values.
fc_list = nn.ModuleList(
    [nn.Linear(d_token, 2 + 1024) for _ in range(boundary_num)]
).to(device)



In [4]:
# Random stand-in inputs: two image batches of shape (3, 3, 224, 224) and
# integer boundary points with coordinates drawn from [0, 224).
previous_img = torch.rand((3, 3, 224, 224)).to(device)
current_img = torch.rand((3, 3, 224, 224)).to(device)
boundary = torch.randint(low=0, high=224, size=(3, boundary_num, 2)).to(device)
boundary

tensor([[[182,  74],
         [194, 115],
         [142,  60],
         [ 86, 194],
         [ 27, 157],
         [134, 179],
         [198, 114],
         [193, 155],
         [155,  94],
         [223,  25],
         [ 94, 170],
         [168, 111],
         [216, 222],
         [ 22, 150],
         [ 59,  63],
         [ 43, 149],
         [ 11,  47],
         [128, 159],
         [180,  29],
         [139, 121],
         [211,  45],
         [164, 197],
         [ 45, 200],
         [  7, 183],
         [195,  84],
         [ 20, 183],
         [147, 121],
         [130,  36],
         [212,  74],
         [195, 146],
         [ 68,  32],
         [137, 165],
         [140, 150],
         [116, 113],
         [ 46,  52],
         [169,  38],
         [109,  17],
         [ 46, 123],
         [ 87, 130],
         [149,  77],
         [ 15,   3],
         [160, 105],
         [183, 138],
         [159,   9],
         [178, 209],
         [206, 124],
         [138, 181],
         [123

In [5]:
# Build the packaged model and place it on the configured device in one step.
model = NeighborModel().to(device)


In [6]:
# Check the layout of the boundary tensor before it is passed to the model.
boundary.shape

torch.Size([3, 80, 2])

In [7]:
# Forward pass of NeighborModel on the two image batches and the boundary points.
results = model(previous_img, current_img, boundary)

In [9]:
# Number of entries returned by the model call.
len(results)

6

In [11]:
# Inspect the first entry of the model output.
results[0]

tensor([[[182,  74],
         [194, 115],
         [142,  60],
         [ 86, 194],
         [ 27, 157],
         [134, 179],
         [198, 114],
         [193, 155],
         [155,  94],
         [223,  25],
         [ 93, 170],
         [168, 111],
         [216, 222],
         [ 22, 150],
         [ 59,  63],
         [ 43, 149],
         [ 11,  47],
         [128, 159],
         [180,  29],
         [139, 121],
         [211,  45],
         [164, 197],
         [ 45, 200],
         [  7, 183],
         [195,  84],
         [ 20, 183],
         [147, 121],
         [130,  36],
         [212,  74],
         [195, 146],
         [ 68,  32],
         [137, 165],
         [140, 149],
         [116, 113],
         [ 46,  52],
         [169,  38],
         [109,  17],
         [ 46, 123],
         [ 87, 130],
         [149,  77],
         [ 15,   3],
         [160, 105],
         [183, 138],
         [159,   9],
         [178, 209],
         [206, 124],
         [138, 181],
         [123

In [105]:
def _upsampled_backbone_features(img: torch.Tensor) -> torch.Tensor:
    """Run ``res50_bone`` on ``img`` and bilinearly resize to ``img``'s H x W.

    The target size is read from the input instead of being hard-coded, so
    feature maps stay aligned with pixel coordinates for any image size.
    """
    return F.interpolate(
        res50_bone(img),
        size=tuple(img.shape[-2:]),
        mode="bilinear",
    )


# Same pipeline for both images; the variable names keep their original
# spelling (`freature`, sic) because later cells refer to them.
pre_img_freature = _upsampled_backbone_features(previous_img)
cur_img_freature = _upsampled_backbone_features(current_img)

In [106]:
# Sample the previous-image feature map at every boundary point; these
# per-point vectors are used as the queries in the cells below.
query_features = get_bou_features(pre_img_freature, boundary)
query_features.shape

torch.Size([3, 1024, 80])

In [107]:
# Inline version of the per-item lookup done by get_bou_features, applied to
# batch item 0 of the current-image features.
cur_img_freature[0, :, boundary[0, :, 0], boundary[0, :, 1]]

tensor([[2.3984e-01, 7.6721e-02, 4.1830e-01,  ..., 6.2594e-02, 5.1910e-03,
         4.1260e-01],
        [4.2980e-01, 2.5420e-01, 2.8382e-01,  ..., 1.7188e+00, 3.1121e-01,
         4.4717e-01],
        [1.6919e-02, 0.0000e+00, 2.0183e+00,  ..., 4.0063e-01, 9.4191e-01,
         5.7375e-03],
        ...,
        [2.2230e-01, 7.1560e-02, 1.4204e+00,  ..., 2.1656e-01, 1.9849e+00,
         3.7848e-01],
        [1.1351e+00, 3.0416e-01, 3.2139e+00,  ..., 1.1890e+00, 5.7572e-01,
         1.8277e+00],
        [0.0000e+00, 0.0000e+00, 1.3709e-02,  ..., 1.4319e-01, 0.0000e+00,
         5.8843e-04]], device='cuda:0', grad_fn=<IndexBackward0>)

In [108]:
# Re-check the boundary layout.
boundary.shape

torch.Size([3, 80, 2])

In [109]:
# Show the current boundary coordinates.
boundary

tensor([[[124,  41],
         [142, 189],
         [178, 184],
         [ 37,  76],
         [ 43, 207],
         [150,  12],
         [ 52,  45],
         [ 79,  88],
         [ 60, 176],
         [185, 202],
         [ 17,  39],
         [125, 128],
         [207, 219],
         [ 24, 185],
         [126, 132],
         [ 37,  92],
         [105, 115],
         [106, 122],
         [185, 185],
         [ 20,  29],
         [ 14, 181],
         [195, 108],
         [190, 125],
         [  9,  99],
         [ 33, 108],
         [ 87,  75],
         [ 21,  75],
         [ 17, 151],
         [  6, 153],
         [207, 131],
         [217, 129],
         [214,  96],
         [ 25, 202],
         [158, 216],
         [ 18, 177],
         [ 76,  75],
         [ 92, 208],
         [168,  71],
         [187, 177],
         [140, 188],
         [  7, 179],
         [ 60, 179],
         [175,  18],
         [ 11, 138],
         [ 68, 165],
         [188,  18],
         [ 96,  21],
         [198

In [110]:
# Boundary coordinates after integer division by 4 (floor division of indices).
boundary // 4

tensor([[[31, 10],
         [35, 47],
         [44, 46],
         [ 9, 19],
         [10, 51],
         [37,  3],
         [13, 11],
         [19, 22],
         [15, 44],
         [46, 50],
         [ 4,  9],
         [31, 32],
         [51, 54],
         [ 6, 46],
         [31, 33],
         [ 9, 23],
         [26, 28],
         [26, 30],
         [46, 46],
         [ 5,  7],
         [ 3, 45],
         [48, 27],
         [47, 31],
         [ 2, 24],
         [ 8, 27],
         [21, 18],
         [ 5, 18],
         [ 4, 37],
         [ 1, 38],
         [51, 32],
         [54, 32],
         [53, 24],
         [ 6, 50],
         [39, 54],
         [ 4, 44],
         [19, 18],
         [23, 52],
         [42, 17],
         [46, 44],
         [35, 47],
         [ 1, 44],
         [15, 44],
         [43,  4],
         [ 2, 34],
         [17, 41],
         [47,  4],
         [24,  5],
         [49, 48],
         [42, 29],
         [ 8, 26],
         [20, 31],
         [51, 31],
         [18

In [51]:
# Check that zero padding is readable at the far corner of the padded map.
# Work on a copy: writing 229 into the shared `boundary` would leave it out of
# range for the unpadded 224x224 feature maps indexed by other cells.
probe_boundary = boundary.clone()
# Last valid index after padding 3 on each side: 224 + 2 * 3 - 1 = 229.
probe_boundary[0, 2] = 224 + 2 * 3 - 1
tmp_img_feature = F.pad(cur_img_freature, (3, 3, 3, 3), "constant", 0)
tmp_img_feature[0, :, probe_boundary[0, :, 0], probe_boundary[0, :, 1]][:, 2]

tensor([0., 0., 0.,  ..., 0., 0., 0.], grad_fn=<SelectBackward0>)

In [72]:
# Shape of the upsampled current-image feature map.
cur_img_freature.shape

torch.Size([3, 1024, 224, 224])

In [111]:
# Halve the spatial resolution with 2x2 average pooling (stride 2).
# The earlier `tmp_img_feature = cur_img_freature` was overwritten on the next
# line without being read, so it has been removed.
tmp_img_feature = F.avg_pool2d(cur_img_freature, 2, 2)
tmp_img_feature.shape

torch.Size([3, 1024, 112, 112])

In [113]:
def get_neighbor_dot(
    query_features: torch.Tensor,
    boundary: torch.Tensor,
    cur_img_freature: torch.Tensor,
    scale_level: int,
    radius: int = 3,
) -> torch.Tensor:
    """Dot each query vector with the features in a window around its point.

    Args:
        query_features: Query vectors indexed as ``[batch, channel, point]``.
        boundary: Integer coordinates indexed as ``[batch, point, 2]``;
            component 0 indexes dim 2 of the feature map, component 1 dim 3.
        cur_img_freature: Feature maps indexed as ``[batch, channel, dim2, dim3]``.
        scale_level: When positive, the feature map is average-pooled by
            ``2 ** scale_level`` and the coordinates are floor-divided by the
            same factor before the lookup.
        radius: Half-width of the square window (default 3, i.e. 7 x 7).

    Returns:
        Tensor of shape ``(batch, point, (2 * radius + 1) ** 2)``. Window
        entries are ordered with the dim-2 offset varying slowest. Positions
        that fall outside the feature map contribute 0 (zero padding).

    Raises:
        ValueError: If ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError(f"radius must be non-negative, got {radius}")

    if scale_level > 0:
        stride = 2 ** scale_level
        cur_img_freature = F.avg_pool2d(cur_img_freature, stride, stride)
        boundary = boundary // stride

    # Zero-pad so every window position is a valid index.
    padded = F.pad(cur_img_freature, (radius,) * 4, "constant", 0)

    # All window offsets at once, dim-2 offset outermost (same order as a
    # nested "for x_offset / for y_offset" loop).
    side = torch.arange(-radius, radius + 1, device=boundary.device)
    row_offsets = side.repeat_interleave(side.numel())
    col_offsets = side.repeat(side.numel())

    # "+ radius" converts original coordinates into padded coordinates.
    # Without it the window would be centred `radius` cells up/left of the
    # point and negative indices would silently wrap around.
    rows = boundary[..., 0].unsqueeze(-1) + row_offsets + radius
    cols = boundary[..., 1].unsqueeze(-1) + col_offsets + radius
    batch_idx = torch.arange(
        boundary.shape[0], device=boundary.device
    ).view(-1, 1, 1)

    # keys: (batch, point, window, channel)
    keys = padded.permute(0, 2, 3, 1)[batch_idx, rows, cols]
    # queries: (batch, point, channel, 1)
    queries = query_features.transpose(1, 2).unsqueeze(-1)
    # Squeeze only the trailing singleton so batch/point axes of size 1 survive.
    return torch.matmul(keys, queries).squeeze(-1)


# Smoke test at scale_level=3: only the output shape is inspected.
get_neighbor_dot(query_features, boundary, cur_img_freature, 3).shape

torch.Size([3, 80, 49])

In [114]:
# Neighbourhood dot products at every pyramid level, concatenated along the
# last axis. The level count comes from `scale_level` (the same constant that
# sizes `d_token`) rather than a separate literal, so the two cannot drift.
dot_patches = torch.cat(
    [
        get_neighbor_dot(query_features, boundary, cur_img_freature, level)
        for level in range(scale_level)
    ],
    dim=2,
)
dot_patches.shape

torch.Size([3, 80, 196])

In [115]:
# Shapes of the three pieces that are concatenated into tokens, one per line.
print(dot_patches.shape, query_features.shape, boundary.shape, sep="\n")

torch.Size([3, 80, 196])
torch.Size([3, 1024, 80])
torch.Size([3, 80, 2])


In [116]:
# One token per boundary point: sampled features, multi-scale dot products and
# the point's coordinates (cast to float), joined along the last axis.
tokens = torch.cat(
    (
        query_features.permute(0, 2, 1),
        dot_patches,
        boundary.to(torch.float32),
    ),
    dim=2,
)
tokens.shape

torch.Size([3, 80, 1222])

In [117]:
# Apply the positional encoding, then the transformer encoder; the recorded
# output shows the token shape is unchanged. `tranformer_encoder` (sic) is the
# name under which the encoder was created.
out_tokens = tranformer_encoder(positional_encoding(tokens))
out_tokens.shape

torch.Size([3, 80, 1222])

In [118]:
# Apply the i-th linear head to the i-th token of every sample and stack the
# per-point outputs back along dim 1.
results = torch.stack(
    [fc_list[idx](out_tokens[:, idx, :]) for idx in range(boundary_num)],
    dim=1,
)
results.shape

torch.Size([3, 80, 1026])

In [119]:
# The first two output channels of each head are taken as the boundary offset.
boundary_offset = results[..., :2]
boundary_offset.shape

torch.Size([3, 80, 2])

In [120]:
# The remaining output channels of each head are taken as the query-feature offset.
query_feature_offset = results[..., 2:]
query_feature_offset.shape

torch.Size([3, 80, 1024])

In [121]:
# query_features is laid out (batch, channel, point), whereas
# query_feature_offset is (batch, point, channel); compare the shapes.
query_features.shape

torch.Size([3, 1024, 80])