In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
import numpy as np
from model import PositionalEncoding
from model import IterWholeFirst
from einops import rearrange

In [11]:
# Inspect GPU memory use and utilisation before placing models on "cuda";
# in the saved output the GPUs were already busy with other work.
!nvidia-smi

Fri Mar 29 12:55:34 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  |   00000000:15:00.0 Off |                    0 |
| N/A   57C    P0            285W /  300W |   10919MiB /  32768MiB |     79%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-32GB           On  |   00

In [12]:
def get_bou_features(
    img_features: torch.Tensor, boundary: torch.Tensor
) -> torch.Tensor:
    """Gather per-sample feature vectors at integer boundary points.

    Args:
        img_features: feature map of shape (B, C, H, W).
        boundary: integer tensor of shape (B, N, 2); ``boundary[b, n, 0]``
            indexes the H axis and ``boundary[b, n, 1]`` the W axis of
            sample ``b``.

    Returns:
        Tensor of shape (B, C, N) where ``out[b, :, n]`` equals
        ``img_features[b, :, boundary[b, n, 0], boundary[b, n, 1]]``.

    Raises:
        ValueError: if ``img_features`` and ``boundary`` have different
            batch sizes (previously this either silently used only the
            first samples of ``img_features`` or failed with an IndexError).
    """
    if img_features.shape[0] != boundary.shape[0]:
        raise ValueError(
            f"batch size mismatch: img_features has {img_features.shape[0]}, "
            f"boundary has {boundary.shape[0]}"
        )
    if boundary.shape[0] == 0:
        # torch.stack cannot stack an empty list; return an empty (0, C, N) tensor.
        return img_features.new_empty((0, img_features.shape[1], boundary.shape[1]))
    # Gather each sample independently ((C, H, W)[:, rows, cols] -> (C, N)) and
    # stack once. Growing the result with torch.cat inside the loop re-copied
    # the accumulated tensor on every iteration (quadratic in batch size).
    return torch.stack(
        [
            sample_features[:, sample_points[:, 0], sample_points[:, 1]]
            for sample_features, sample_points in zip(img_features, boundary)
        ],
        dim=0,
    )

In [13]:
# Number of points sampled along each boundary.
boundary_num = 80
# Fall back to CPU so the notebook also runs on machines without a visible GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed torch's RNG so the dummy inputs and freshly initialised layers below are
# reproducible across Restart & Run All.
torch.manual_seed(0)

# ImageNet-pretrained ResNet-50 (weights are downloaded on first use).
res50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)
# Drop the last three children (layer4, avgpool, fc in torchvision's ResNet),
# leaving a 1024-channel, stride-16 feature extractor (14x14 maps for 224x224
# input, per the shapes printed further down). The child modules already live
# on `device`, so a second .to(device) is unnecessary.
# NOTE(review): the backbone stays in training mode, so each forward pass uses
# batch statistics in BatchNorm and updates its running stats; call .eval()
# (and use torch.no_grad()) if it is meant to be a frozen feature extractor.
res50_bone = nn.Sequential(*list(res50.children())[:-3])

In [14]:
# Dummy inputs for a shape smoke test: first / previous / current frames as
# batches of RGB images, plus integer boundary points inside the image.
batch_size, img_size = 3, 224
img_shape = (batch_size, 3, img_size, img_size)
# Allocate directly on `device` instead of creating on CPU and copying.
first_img = torch.rand(img_shape, device=device)
previous_img = torch.rand(img_shape, device=device)
current_img = torch.rand(img_shape, device=device)
# randint's upper bound is exclusive, so every coordinate is a valid pixel
# index for an img_size x img_size image (tied to img_size, not a literal).
boundary = torch.randint(0, img_size, (batch_size, boundary_num, 2), device=device)
boundary.shape

torch.Size([3, 80, 2])

In [15]:
from model import IterWholeFirst_Con

In [16]:
# Full model from model.py; its forward pass is exercised on the dummy inputs below.
# NOTE(review): a freshly constructed module is in training mode, so any dropout
# inside it is active during the forward call — confirm this is intended here.
model = IterWholeFirst_Con().to(device)

In [17]:
# NOTE(review): `boundary` is passed twice — presumably once as the first-frame
# boundary and once as the previous-frame boundary; confirm against the
# signature of IterWholeFirst_Con.forward.
results = model(first_img, boundary, previous_img, current_img, boundary)

In [18]:
# Inspect the first returned element. The printed tensor has a trailing
# dimension of size 2 with values in roughly [0, 224) — NOTE(review): presumably
# predicted boundary coordinates; confirm against the model's return contract.
results[0]

tensor([[[218.0502,  73.3597],
         [185.1082,  46.1478],
         [ 67.2036, 179.0918],
         [ 86.9212,  99.9083],
         [ 36.2933, 122.3697],
         [121.7536, 139.1214],
         [193.4686,  34.7747],
         [ 67.9449,  84.7661],
         [117.9955,  90.3885],
         [ 49.5312, 212.6399],
         [161.0197,  28.1296],
         [116.2467, 193.0450],
         [ 55.1245, 119.8204],
         [ 85.9165,   0.6317],
         [ 72.8476, 104.9932],
         [122.7785,  72.9933],
         [ 90.9525, 160.9684],
         [ 32.0098, 152.8551],
         [ 28.3327, 201.2262],
         [130.8240,  37.7677],
         [ 67.1918, 185.3790],
         [ 46.8778,  20.7873],
         [185.9105, 199.4985],
         [ 99.4514, 223.1400],
         [155.0369, 149.6918],
         [121.1199, 178.3872],
         [155.8786, 114.0323],
         [ 75.5257, 136.5029],
         [119.4035,  72.2576],
         [120.1593,  29.8580],
         [ 70.2636,  76.7981],
         [191.9312, 160.1306],
        

In [6]:
def extract_upsampled_features(img: torch.Tensor):
    """Run the truncated ResNet-50 on `img` and upsample to the image size.

    Returns:
        (raw, upsampled): `raw` is the backbone output (14x14 for 224x224
        input, per the printed shapes); `upsampled` is `raw` bilinearly
        resized to img's own H x W, so integer boundary pixel coordinates
        index it directly whatever the input resolution is.
    """
    raw = res50_bone(img)
    upsampled = F.interpolate(
        raw,
        size=tuple(img.shape[-2:]),
        mode="bilinear",
        align_corners=False,  # PyTorch's default, made explicit
    )
    return raw, upsampled


# NOTE(review): each upsampled map is (3, 1024, 224, 224) float32 (~588 MiB)
# and autograd keeps the graph alive; sampling only the boundary points
# (e.g. with F.grid_sample) would avoid materialising full-resolution maps.
first_img_raw_feature, first_img_feature = extract_upsampled_features(first_img)
pre_img_raw_feature, pre_img_feature = extract_upsampled_features(previous_img)
cur_img_raw_feature, cur_img_feature = extract_upsampled_features(current_img)
first_img_feature.shape, first_img_raw_feature.shape

(torch.Size([3, 1024, 224, 224]), torch.Size([3, 1024, 14, 14]))

In [41]:
def get_img_tokens(img_features: torch.Tensor) -> torch.Tensor:
    """Flatten a (b, c, h, w) feature map into a sequence of h*w tokens.

    Returns a tensor of shape (b, h*w, c). Because "(h w)" groups h as the
    outer axis, tokens are ordered row by row, left to right.
    """
    return rearrange(img_features, "b c h w -> b (h w) c")


get_img_tokens(first_img_raw_feature).shape

torch.Size([3, 196, 1024])

In [67]:
# Only the current frame's backbone map becomes encoder tokens; the first and
# previous frames contribute sampled boundary queries instead.
cur_img_tokens = get_img_tokens(img_features=cur_img_raw_feature)
cur_img_tokens.shape

torch.Size([3, 196, 1024])

In [63]:
# Token width: 1024 backbone channels plus the 2 boundary coordinates that are
# concatenated onto each boundary feature vector.
d_token = 1024 + 2
positional_encoding = PositionalEncoding(d_token).to(device)
# NOTE(review): this single LayerNorm instance is shared by the image tokens,
# the boundary queries and the boundary tokens alike.
layernorm = nn.LayerNorm(d_token).to(device)
# Single-head, single-layer encoder; nn.TransformerEncoder deep-copies the template layer.
encoder_layer = nn.TransformerEncoderLayer(d_model=d_token, nhead=1, batch_first=True).to(device)
tranformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1).to(device)  # (sic) name used later



In [44]:
# Projects 1024-channel image tokens up to the d_token width used by the encoder.
img_token_encoder = nn.Linear(in_features=1024, out_features=d_token).to(device)

In [68]:
# Project, normalise and position-encode the current-frame tokens, then encode
# them into the decoder memory. The result goes to a new name instead of
# overwriting `cur_img_tokens`: re-running the old cell fed already-projected
# d_token-wide tokens back into img_token_encoder (which expects 1024 inputs)
# and crashed.
cur_img_embedded = positional_encoding(layernorm(img_token_encoder(cur_img_tokens)))
memory = tranformer_encoder(cur_img_embedded)
memory.shape

torch.Size([3, 196, 1026])

In [46]:
# Sample upsampled features at the boundary points in the first and previous
# frames, laid out as (batch, boundary_num, channels) query sequences.
first_query, pre_query = (
    get_bou_features(frame_features, boundary).permute(0, 2, 1)
    for frame_features in (first_img_feature, pre_img_feature)
)
first_query.shape


torch.Size([3, 80, 1024])

In [49]:
# Projects 1024-channel sampled boundary features to the d_token width.
query_encoder = nn.Linear(in_features=1024, out_features=d_token).to(device)

In [57]:
# Previous-frame queries: project to d_token, normalise, add positional encoding.
pre_query_tokens = positional_encoding(layernorm(query_encoder(pre_query)))
pre_query_tokens.shape

torch.Size([3, 80, 1026])

In [58]:
# First-frame queries: same projection / normalisation / positional pipeline.
first_query_tokens = positional_encoding(layernorm(query_encoder(first_query)))
first_query_tokens.shape

torch.Size([3, 80, 1026])

In [59]:
# Current-frame features at the boundary points, as (batch, boundary_num, channels).
boundary_features = get_bou_features(cur_img_feature, boundary).transpose(1, 2)
boundary_features.shape

torch.Size([3, 80, 1024])

In [60]:
# Append each point's (row, col) coordinates to its feature vector -> d_token wide.
# The cast spells out the integer -> float promotion torch.cat applies implicitly.
# NOTE(review): raw pixel coordinates (0..223) are likely on a very different
# scale from the backbone features, and the LayerNorm applied next normalises
# across both; consider rescaling the coordinates (e.g. to [0, 1]).
boundary_tokens = torch.cat([boundary_features, boundary.to(boundary_features.dtype)], dim=2)
boundary_tokens.shape

torch.Size([3, 80, 1026])

In [61]:
# Normalise and position-encode the boundary tokens. They are rebuilt from
# `boundary_features` and `boundary` instead of the previous value of
# `boundary_tokens`, so re-running this cell no longer applies LayerNorm and
# the positional encoding a second time.
boundary_tokens = positional_encoding(
    layernorm(
        torch.cat(
            [boundary_features, boundary.to(boundary_features.dtype)], dim=2
        )
    )
)
boundary_tokens.shape

torch.Size([3, 80, 1026])

In [62]:
# Single-head, single-layer decoder: the query/boundary tokens cross-attend to
# the encoder memory. nn.TransformerDecoder deep-copies `decoder_layer`, so the
# template itself is not part of the decoder.
decoder_layer = nn.TransformerDecoderLayer(d_model=d_token, nhead=1, batch_first=True).to(device)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=1).to(device)

In [69]:
# Decoder target layout along the sequence axis, boundary_num tokens each:
# [previous-frame queries | current boundary tokens | first-frame queries].
# The output heads below slice the decoder output back by these positions.
input_tokens = torch.cat((pre_query_tokens, boundary_tokens, first_query_tokens), dim=1)
input_tokens.shape

torch.Size([3, 240, 1026])

In [71]:
# No masks: every target token attends to all others and to the full memory.
out_tokens = transformer_decoder(tgt=input_tokens, memory=memory)
out_tokens.shape

torch.Size([3, 240, 1026])

In [72]:
# Linear heads on decoder outputs: a 2-value offset per boundary point, and a
# d_token-wide offset presumably used to update the query tokens.
xy_offset_encoder = nn.Linear(in_features=d_token, out_features=2).to(device)
q_offset_encoder = nn.Linear(in_features=d_token, out_features=d_token).to(device)

In [73]:
# The first boundary_num decoder outputs sit at the previous-frame query positions.
q_offset = q_offset_encoder(out_tokens[:, 0:boundary_num])
q_offset.shape

torch.Size([3, 80, 1026])

In [74]:
# The middle boundary_num outputs sit at the current-frame boundary token positions.
boundary_slice = slice(boundary_num, 2 * boundary_num)
xy_offset = xy_offset_encoder(out_tokens[:, boundary_slice])
xy_offset.shape

torch.Size([3, 80, 2])

In [75]:
# Shift the integer boundary by the predicted offsets. The result is float, so it
# must be rounded and clamped before being reused as pixel indices.
torch.add(boundary, xy_offset).shape

torch.Size([3, 80, 2])