In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
import numpy as np
from model import PositionalEncoding
from model import NeighborModel

In [2]:
def get_bou_features(
    img_features: torch.Tensor, boundary: torch.Tensor
) -> torch.Tensor:
    bou_features = img_features[
        0, :, boundary[0, :, 0], boundary[0, :, 1]
    ].unsqueeze(0)
    for i in range(1, boundary.shape[0]):
        bou_features = torch.cat(
            (
                bou_features,
                img_features[
                    i,
                    :,
                    boundary[i, :, 0],
                    boundary[i, :, 1],
                ].unsqueeze(0),
            ),
            dim=0,
        )
    return bou_features

In [3]:
d_token = 1024 + 2
boundary_num = 80
device = "cpu"
res50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)
res50_bone = nn.Sequential(*list(res50.children())[:-3]).to(device)
positional_encoding = PositionalEncoding(d_token).to(device)
layernorm = nn.LayerNorm(d_token).to(device)
query_encoder = nn.Linear(1024, d_token).to(device)
tranformer_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=d_token,
        nhead=1,
        batch_first=True,
    ),
    num_layers=1,
).to(device)
xy_offset_encoder = nn.Linear(d_token, 2).to(device)
q_offset_encoder = nn.Linear(d_token, d_token).to(device)
# xy_fc_list = nn.ModuleList()
# for i in range(boundary_num):
#     xy_fc_list.append(nn.Linear(d_token, 2))
# xy_fc_list = xy_fc_list.to(device)
# q_fc_list = nn.ModuleList()
# for i in range(boundary_num):
#     q_fc_list.append(nn.Linear(d_token, d_token))
# q_fc_list = q_fc_list.to(device)



In [4]:
first_img = torch.rand(3, 3, 224, 224).to(device)
previous_img = torch.rand(3, 3, 224, 224).to(device)
current_img = torch.rand(3, 3, 224, 224).to(device)
boundary = torch.randint(0, 224, (3, boundary_num, 2)).to(device)
boundary.shape

torch.Size([3, 80, 2])

In [7]:
first_img_feature = res50_bone(first_img)
first_img_feature = F.interpolate(
    first_img_feature,
    size=(224, 224),
    mode="bilinear",
)
pre_img_feature = res50_bone(previous_img)
pre_img_feature = F.interpolate(
    pre_img_feature,
    size=(224, 224),
    mode="bilinear",
)
cur_img_feature = res50_bone(current_img)
cur_img_feature = F.interpolate(
    cur_img_feature,
    size=(224, 224),
    mode="bilinear",
)
first_img_feature.shape

torch.Size([3, 1024, 224, 224])

In [15]:
def get_neighbor_features(
    img_features: torch.Tensor, 
    boundary: torch.Tensor,
    scale_level: int
): 
    device = img_features.device
    if scale_level > 0:
        img_features = F.avg_pool2d(
            img_features, 
            2 ** scale_level,
            2 ** scale_level,
        )
        boundary = boundary // (2 ** scale_level)
    img_features = F.pad(img_features, (1, 1, 1, 1), "constant", 0)
    neighor_features = []
    for x_offset in range(-1, 2):
        for y_offset in range(-1, 2):
            neighor_features.append(
                get_bou_features(
                    img_features,
                    boundary + torch.tensor([x_offset, y_offset]).to(device),
                )
            )
    return torch.cat(neighor_features, dim=1)
def get_neighbor_features_with_scales(
    img_features: torch.Tensor, 
    boundary: torch.Tensor,
    scale_levels: list[int]
):
    neighbor_features = get_neighbor_features(img_features, boundary, scale_levels[0])
    for scale_level in scale_levels[1:]:
        neighbor_features = torch.cat(
            (
                neighbor_features,
                get_neighbor_features(img_features, boundary, scale_level),
            ),
            dim=1,
        )
    return neighbor_features
get_neighbor_features_with_scales(first_img_feature, boundary, [0, 2]).shape

torch.Size([3, 18432, 80])

In [39]:
first_query = get_bou_features(first_img_feature, boundary)
first_query = first_query.permute(0, 2, 1)
pre_query = get_bou_features(pre_img_feature, boundary)
pre_query = pre_query.permute(0, 2, 1)
first_query.shape


torch.Size([3, 80, 1024])

In [41]:
def get_best_match(feature0: torch.Tensor, feature1: torch.Tensor):
    best_shift = 0
    best_similarity = (feature0 * feature1).sum() 
    for shift in range(feature0.shape[1]):
        similarity = (feature0 * feature1.roll(shift, dims=1)).sum()
        if similarity > best_similarity:
            best_similarity = similarity
            best_shift = shift
    return feature0, feature1.roll(best_shift, dims=1)
    
get_best_match(first_query, pre_query)

tensor(98915.5312, grad_fn=<SumBackward0>)
tensor(98915.5312, grad_fn=<SumBackward0>)
tensor(90949.9922, grad_fn=<SumBackward0>)
tensor(90598.2734, grad_fn=<SumBackward0>)
tensor(90960.5625, grad_fn=<SumBackward0>)
tensor(90725.2656, grad_fn=<SumBackward0>)
tensor(90631.5469, grad_fn=<SumBackward0>)
tensor(90929.3281, grad_fn=<SumBackward0>)
tensor(91280.0703, grad_fn=<SumBackward0>)
tensor(90991.5234, grad_fn=<SumBackward0>)
tensor(91189.1953, grad_fn=<SumBackward0>)
tensor(90458.0156, grad_fn=<SumBackward0>)
tensor(90843.2188, grad_fn=<SumBackward0>)
tensor(90949.4531, grad_fn=<SumBackward0>)
tensor(91146.4219, grad_fn=<SumBackward0>)
tensor(90479.8516, grad_fn=<SumBackward0>)
tensor(90386.1484, grad_fn=<SumBackward0>)
tensor(91073.7031, grad_fn=<SumBackward0>)
tensor(90640.0312, grad_fn=<SumBackward0>)
tensor(91199.4062, grad_fn=<SumBackward0>)
tensor(90161.0078, grad_fn=<SumBackward0>)
tensor(91157.1016, grad_fn=<SumBackward0>)
tensor(90881.9062, grad_fn=<SumBackward0>)
tensor(9119

In [11]:
raw_query_features = get_bou_features(pre_img_feature, boundary)
raw_query_features.shape

torch.Size([3, 1024, 80])

In [12]:
query_features = query_encoder(raw_query_features.permute(0, 2, 1))
query_features.shape

torch.Size([3, 80, 1026])

In [14]:
boundary_features = get_bou_features(cur_img_feature, boundary)
boundary_features = boundary_features.permute(0, 2, 1)
boundary_features.shape

torch.Size([3, 80, 1024])

In [15]:
boundary_tokens = torch.cat([boundary_features, boundary], dim=2)
boundary_tokens.shape

torch.Size([3, 80, 1026])

In [16]:
tokens = torch.cat([query_features, boundary_tokens], dim=1)
tokens.shape

torch.Size([3, 160, 1026])

In [17]:
tokens = layernorm(tokens)
tokens = positional_encoding(tokens)

In [18]:
out_tokens = tranformer_encoder(tokens)
out_tokens.shape

torch.Size([3, 160, 1026])

In [23]:
query_offsets = tokens[:, :boundary_num, :]
query_offsets = q_offset_encoder(query_offsets)
query_offsets.shape

torch.Size([3, 80, 1026])

In [14]:
xy_offsets = tokens[:, boundary_num:, :]
xy_offsets = xy_offset_encoder(xy_offsets)
xy_offsets

NameError: name 'tokens' is not defined