In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
import numpy as np
from model import PositionalEncoding

In [15]:
def get_bou_features(
    img_features: torch.Tensor, boundary: torch.Tensor
) -> torch.Tensor:
    bou_features = img_features[
        0, :, boundary[0, :, 0], boundary[0, :, 1]
    ].unsqueeze(0)
    for i in range(1, boundary.shape[0]):
        bou_features = torch.cat(
            (
                bou_features,
                img_features[
                    i,
                    :,
                    boundary[i, :, 0],
                    boundary[i, :, 1],
                ].unsqueeze(0),
            ),
            dim=0,
        )
    return bou_features

In [3]:
scale_level = 4
d_token = 1024 + 7 * 7 * scale_level + 2
boundary_num = 80
device = "cpu"
res50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)
res50_bone = nn.Sequential(*list(res50.children())[:-3]).to(device)
positional_encoding = PositionalEncoding(d_token).to(device)
tranformer_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=d_token,
        nhead=1,
        batch_first=True,
    ),
    num_layers=1,
).to(device)
fc_list = nn.ModuleList()
for i in range(boundary_num):
    fc_list.append(nn.Linear(d_token, 2))
fc_list = fc_list.to(device)



In [52]:
previous_img = torch.rand(3, 3, 224, 224).to(device)
current_img = torch.rand(3, 3, 224, 224).to(device)
boundary = torch.randint(0, 224, (3, boundary_num, 2)).to(device)
boundary

tensor([[[  8, 124],
         [ 41, 200],
         [206, 130],
         [  3, 118],
         [ 64, 220],
         [  5,  35],
         [221, 161],
         [ 72,  79],
         [112,  97],
         [ 94, 112],
         [169,   2],
         [ 16,  43],
         [137, 222],
         [223,   6],
         [197, 185],
         [ 80, 205],
         [ 97,   0],
         [151, 138],
         [117, 155],
         [ 96, 121],
         [216, 125],
         [  0, 212],
         [ 63, 141],
         [185, 145],
         [ 81, 196],
         [ 76,  73],
         [ 93,  53],
         [168, 124],
         [186, 208],
         [168, 177],
         [  0, 138],
         [165, 141],
         [166,  93],
         [ 12,  56],
         [ 30, 110],
         [209,   5],
         [139, 109],
         [ 38, 184],
         [120, 161],
         [207,  39],
         [117, 210],
         [ 40,  96],
         [ 91,  46],
         [169, 147],
         [199, 207],
         [ 25,   5],
         [216, 197],
         [  9

In [14]:
tmp = boundary[0, 0]
(tmp > 224).sum()

tensor(0)

In [5]:
pre_img_freature = res50_bone(previous_img)
pre_img_freature = F.interpolate(
    pre_img_freature,
    size=(224, 224),
    mode="bilinear",
)
cur_img_freature = res50_bone(current_img)
cur_img_freature = F.interpolate(
    cur_img_freature,
    size=(224, 224),
    mode="bilinear",
)

In [6]:
query_features = get_bou_features(pre_img_freature, boundary)
query_features.shape

torch.Size([3, 1024, 80])

In [22]:
boundary[0, 0]

tensor([123,  31])

In [24]:
cur_img_freature[0, :, boundary[0, :, 0], boundary[0, :, 1]]

tensor([[8.5077e-01, 3.0929e-02, 1.6007e-02,  ..., 0.0000e+00, 0.0000e+00,
         1.1826e-02],
        [4.3627e-01, 8.3692e-01, 1.6735e+00,  ..., 7.7271e-01, 5.6761e-01,
         4.4662e-01],
        [3.8950e-03, 5.5361e-04, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [5.2476e-01, 3.0578e-01, 2.8454e-01,  ..., 6.2086e-03, 1.2706e-01,
         4.7585e-01],
        [1.3739e+00, 1.4901e+00, 1.3674e+00,  ..., 1.5380e+00, 4.2726e-01,
         1.1802e+00],
        [5.0908e-01, 0.0000e+00, 1.9151e-01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]], grad_fn=<IndexBackward0>)

In [45]:
boundary.shape

torch.Size([3, 80, 2])

In [55]:
boundary

tensor([[[  8, 124],
         [ 41, 200],
         [206, 130],
         [  3, 118],
         [ 64, 220],
         [  5,  35],
         [221, 161],
         [ 72,  79],
         [112,  97],
         [ 94, 112],
         [169,   2],
         [ 16,  43],
         [137, 222],
         [223,   6],
         [197, 185],
         [ 80, 205],
         [ 97,   0],
         [151, 138],
         [117, 155],
         [ 96, 121],
         [216, 125],
         [  0, 212],
         [ 63, 141],
         [185, 145],
         [ 81, 196],
         [ 76,  73],
         [ 93,  53],
         [168, 124],
         [186, 208],
         [168, 177],
         [  0, 138],
         [165, 141],
         [166,  93],
         [ 12,  56],
         [ 30, 110],
         [209,   5],
         [139, 109],
         [ 38, 184],
         [120, 161],
         [207,  39],
         [117, 210],
         [ 40,  96],
         [ 91,  46],
         [169, 147],
         [199, 207],
         [ 25,   5],
         [216, 197],
         [  9

In [79]:
boundary//4

tensor([[[ 2, 31],
         [10, 50],
         [51, 32],
         [ 0, 29],
         [16, 55],
         [ 1,  8],
         [55, 40],
         [18, 19],
         [28, 24],
         [23, 28],
         [42,  0],
         [ 4, 10],
         [34, 55],
         [55,  1],
         [49, 46],
         [20, 51],
         [24,  0],
         [37, 34],
         [29, 38],
         [24, 30],
         [54, 31],
         [ 0, 53],
         [15, 35],
         [46, 36],
         [20, 49],
         [19, 18],
         [23, 13],
         [42, 31],
         [46, 52],
         [42, 44],
         [ 0, 34],
         [41, 35],
         [41, 23],
         [ 3, 14],
         [ 7, 27],
         [52,  1],
         [34, 27],
         [ 9, 46],
         [30, 40],
         [51,  9],
         [29, 52],
         [10, 24],
         [22, 11],
         [42, 36],
         [49, 51],
         [ 6,  1],
         [54, 49],
         [ 2, 23],
         [36, 37],
         [14, 31],
         [43,  0],
         [ 4, 10],
         [ 6

In [51]:
boundary[0, 2] = torch.tensor([224, 224])
boundary[0, 2] += 5
tmp_img_feature = F.pad(cur_img_freature, (3, 3, 3, 3), "constant", 0)
tmp_img_feature[0, :, boundary[0, :, 0], boundary[0, :, 1]][:, 2]

tensor([0., 0., 0.,  ..., 0., 0., 0.], grad_fn=<SelectBackward0>)

In [72]:
cur_img_freature.shape

torch.Size([3, 1024, 224, 224])

In [71]:
tmp_img_feature = cur_img_freature
tmp_img_feature = F.avg_pool2d(cur_img_freature, 2, 2)
tmp_img_feature.shape

torch.Size([3, 1024, 112, 112])

In [82]:
def get_neighbor_dot(
    query_features: torch.Tensor,
    boundary: torch.Tensor,
    cur_img_freature: torch.Tensor,
    scale_level: int,
) -> torch.Tensor:
    def get_dot_product(
        query_features: torch.Tensor, key_features: torch.Tensor
    ) -> torch.Tensor:
        return torch.matmul(
            query_features.transpose(1, 2).unsqueeze(2),
            key_features.transpose(1, 2).unsqueeze(2).transpose(2, 3),
        ).squeeze()
    if scale_level > 0:
        cur_img_freature = F.avg_pool2d(cur_img_freature, 2 ** scale_level, 2 ** scale_level)
        boundary = boundary // (2 ** scale_level)
    print(cur_img_freature.shape)

    cur_img_freature = F.pad(cur_img_freature, (3, 3, 3, 3), "constant", 0)
    dot_results = []
    for x_offset in range(-3, 4):
        for y_offset in range(-3, 4):
            dot_results.append(
                get_dot_product(
                    query_features,
                    get_bou_features(
                        cur_img_freature,
                        boundary + torch.tensor([x_offset, y_offset]),
                    ),
                )
            )
    return torch.stack(dot_results, dim=2)


get_neighbor_dot(query_features, boundary, cur_img_freature, 3).shape

torch.Size([3, 1024, 28, 28])


torch.Size([3, 80, 49])