In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
import numpy as np
from model import PositionalEncoding
from model import NeighborModel

In [2]:
def get_bou_features(
    img_features: torch.Tensor, boundary: torch.Tensor
) -> torch.Tensor:
    bou_features = img_features[
        0, :, boundary[0, :, 0], boundary[0, :, 1]
    ].unsqueeze(0)
    for i in range(1, boundary.shape[0]):
        bou_features = torch.cat(
            (
                bou_features,
                img_features[
                    i,
                    :,
                    boundary[i, :, 0],
                    boundary[i, :, 1],
                ].unsqueeze(0),
            ),
            dim=0,
        )
    return bou_features

In [3]:
d_token = 1024 + 2
boundary_num = 80
device = "cpu"
res50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)
res50_bone = nn.Sequential(*list(res50.children())[:-3]).to(device)
positional_encoding = PositionalEncoding(d_token).to(device)
layernorm = nn.LayerNorm(d_token).to(device)
query_encoder = nn.Linear(1024, d_token).to(device)
tranformer_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=d_token,
        nhead=1,
        batch_first=True,
    ),
    num_layers=1,
).to(device)
xy_offset_encoder = nn.Linear(d_token, 2).to(device)
q_offset_encoder = nn.Linear(d_token, d_token).to(device)
# xy_fc_list = nn.ModuleList()
# for i in range(boundary_num):
#     xy_fc_list.append(nn.Linear(d_token, 2))
# xy_fc_list = xy_fc_list.to(device)
# q_fc_list = nn.ModuleList()
# for i in range(boundary_num):
#     q_fc_list.append(nn.Linear(d_token, d_token))
# q_fc_list = q_fc_list.to(device)



In [4]:
previous_img = torch.rand(3, 3, 224, 224).to(device)
current_img = torch.rand(3, 3, 224, 224).to(device)
boundary = torch.randint(0, 224, (3, boundary_num, 2)).to(device)
boundary

tensor([[[ 90,  45],
         [ 97,  71],
         [  4, 219],
         [  8,  81],
         [159, 138],
         [  9, 117],
         [ 96, 188],
         [147,   2],
         [ 29, 128],
         [126,  93],
         [153,  88],
         [131,  96],
         [ 53, 177],
         [112, 133],
         [117, 212],
         [ 85, 204],
         [ 14, 147],
         [176, 221],
         [ 72, 121],
         [ 90, 174],
         [214, 128],
         [218, 211],
         [185,  13],
         [ 92, 143],
         [197, 176],
         [141, 117],
         [124,  41],
         [ 93,  22],
         [108, 172],
         [141,  74],
         [ 48,  83],
         [135, 217],
         [192, 130],
         [ 85, 181],
         [ 30,  21],
         [ 57,  97],
         [176, 208],
         [128,  24],
         [ 85,  56],
         [173, 115],
         [176,  97],
         [150, 135],
         [204, 152],
         [ 14, 123],
         [ 68, 210],
         [114, 168],
         [  1, 207],
         [ 19

In [5]:
from model import IterativeModel
model = IterativeModel().to(device)

In [6]:
results = model(previous_img, current_img, boundary)

In [12]:
results[0][1].clamp(-1, 1)

tensor([[[ 1.8220e-01, -6.8927e-01],
         [ 1.6962e-01, -7.8950e-01],
         [ 2.0314e-01, -6.2891e-01],
         [ 4.6814e-01, -9.0766e-01],
         [-2.9895e-02, -6.4578e-01],
         [ 8.1367e-02, -5.6777e-01],
         [ 6.7420e-02, -8.0644e-01],
         [-7.3561e-02, -7.2336e-01],
         [-6.5439e-02, -1.0000e+00],
         [ 1.9946e-01, -6.8001e-01],
         [-9.6852e-02, -1.0000e+00],
         [ 1.4019e-01, -9.9435e-01],
         [-9.3980e-02, -8.9267e-01],
         [ 5.2473e-02, -4.4834e-01],
         [ 6.9320e-02, -7.0670e-01],
         [ 3.8935e-01, -6.1003e-01],
         [ 4.6969e-01, -6.7711e-01],
         [ 2.6487e-01, -5.3919e-01],
         [ 1.8914e-01, -1.0000e+00],
         [ 2.3503e-01, -5.1843e-01],
         [ 4.4684e-01, -7.5424e-01],
         [ 1.7829e-01, -3.2491e-01],
         [ 2.2369e-01, -6.0341e-01],
         [ 3.8840e-01, -6.8209e-01],
         [-3.6661e-02, -5.4068e-01],
         [-5.9190e-02, -5.3332e-01],
         [ 1.9813e-02, -2.3972e-01],
 

In [13]:
results[2] is results[0]

False

In [7]:
pre_img_freature = res50_bone(previous_img)
pre_img_freature = F.interpolate(
    pre_img_freature,
    size=(224, 224),
    mode="bilinear",
)
cur_img_freature = res50_bone(current_img)
cur_img_freature = F.interpolate(
    cur_img_freature,
    size=(224, 224),
    mode="bilinear",
)

In [11]:
raw_query_features = get_bou_features(pre_img_freature, boundary)
raw_query_features.shape

torch.Size([3, 1024, 80])

In [12]:
query_features = query_encoder(raw_query_features.permute(0, 2, 1))
query_features.shape

torch.Size([3, 80, 1026])

In [14]:
boundary_features = get_bou_features(cur_img_freature, boundary)
boundary_features = boundary_features.permute(0, 2, 1)
boundary_features.shape

torch.Size([3, 80, 1024])

In [15]:
boundary_tokens = torch.cat([boundary_features, boundary], dim=2)
boundary_tokens.shape

torch.Size([3, 80, 1026])

In [16]:
tokens = torch.cat([query_features, boundary_tokens], dim=1)
tokens.shape

torch.Size([3, 160, 1026])

In [17]:
tokens = layernorm(tokens)
tokens = positional_encoding(tokens)

In [18]:
out_tokens = tranformer_encoder(tokens)
out_tokens.shape

torch.Size([3, 160, 1026])

In [23]:
query_offsets = tokens[:, :boundary_num, :]
query_offsets = q_offset_encoder(query_offsets)
query_offsets.shape

torch.Size([3, 80, 1026])

In [14]:
xy_offsets = tokens[:, boundary_num:, :]
xy_offsets = xy_offset_encoder(xy_offsets)
xy_offsets

NameError: name 'tokens' is not defined