In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
import numpy as np

In [2]:
# Report whether PyTorch can see a CUDA device; the Model cells further down
# move the model and its inputs to "cuda".
torch.cuda.is_available()

True

In [3]:
# Build torchvision's ResNet-50 initialised from the IMAGENET1K_V2 checkpoint
# (downloaded into the torch hub cache the first time this runs).
pretrained_weights = ResNet50_Weights.IMAGENET1K_V2
res50 = resnet50(weights=pretrained_weights)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /zhome/32/f/202284/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████████████████████████| 97.8M/97.8M [00:01<00:00, 75.4MB/s]


In [3]:
# Keep every child module of the ResNet except the final two, so the network
# outputs a spatial feature map (2048 x 7 x 7 per 224 x 224 image) rather than
# a flat prediction vector.
# NOTE(review): the module stays in its default training mode; call .eval()
# if deterministic, inference-style features are wanted.
backbone_layers = list(res50.children())[:-2]
res50_bone = nn.Sequential(*backbone_layers).to('cpu')

In [4]:
# Seed the global RNG so the random stand-in frames, and everything derived
# from them below, are reproducible on Restart & Run All.
torch.manual_seed(0)

# Two batches of three random 3 x 224 x 224 inputs stand in for the previous
# and the current video frame; each goes through the truncated ResNet.
pre_frame = torch.randn(3, 3, 224, 224)
pre_frame_features = res50_bone(pre_frame)
curr_frame = torch.randn(3, 3, 224, 224)
curr_img_features = res50_bone(curr_frame)
curr_img_features.shape, pre_frame_features.shape

(torch.Size([3, 2048, 7, 7]), torch.Size([3, 2048, 7, 7]))

In [5]:
# Toy boundary tensor: 3 samples x 4 points x 2 integer coordinates.
# Built directly as int32 with torch.tensor; the previous
# torch.Tensor(np.array(...)).int() form converted the integers to float32
# and back, which is needless and inexact for large values.
tmp_boundary = torch.tensor(
    [
        [[1, 2], [3, 4], [5, 6], [7, 8]],
        [[9, 10], [11, 12], [13, 14], [15, 16]],
        [[9, 10], [11, 12], [13, 14], [15, 16]],
    ],
    dtype=torch.int32,
)
tmp_boundary.shape

torch.Size([3, 4, 2])

In [6]:
# Bilinearly upsample both 7 x 7 feature maps to 224 x 224 so that boundary
# coordinates can index them directly.
# NOTE(review): each upsampled map holds 3 * 2048 * 224 * 224 float32 values
# (about 1.2 GB), which is a lot of memory for sampling four points per image.
upsample_size = (224, 224)
curr_img_features, pre_img_features = (
    F.interpolate(feature_map, size=upsample_size, mode="bilinear")
    for feature_map in (curr_img_features, pre_frame_features)
)
curr_img_features.shape, pre_img_features.shape

(torch.Size([3, 2048, 224, 224]), torch.Size([3, 2048, 224, 224]))

In [7]:
# First coordinate (index 0 on the last axis) of all four points of sample 2.
tmp_boundary[2, :, 0]

tensor([ 9, 11, 13, 15], dtype=torch.int32)

In [39]:
# Gather, for sample 0, the 2048-channel feature vector under each of its four
# boundary points (result shape: channels x points). The .to("cpu") calls keep
# the index tensors on the CPU even if a later cell has already rebound
# tmp_boundary to "cuda".
curr_img_features[0, :, tmp_boundary[0, :, 0].to("cpu"), tmp_boundary[0, :, 1].to("cpu")].shape

torch.Size([2048, 4])

In [8]:
def get_bou_features(
    img_features: torch.Tensor, boundary: torch.Tensor
) -> torch.Tensor:
    """Gather the feature vector under every boundary point of every sample.

    All samples are gathered with a single advanced-indexing operation
    instead of growing the result with ``torch.cat`` inside a Python loop,
    which also makes an empty batch valid input.

    Args:
        img_features: Feature maps indexed as ``[sample, channel, d2, d3]``.
        boundary: Integer point coordinates indexed as ``[sample, point, 2]``.
            ``boundary[:, :, 0]`` indexes dim 2 of ``img_features`` and
            ``boundary[:, :, 1]`` indexes dim 3.

    Returns:
        A contiguous tensor of shape ``(samples, channels, points)`` with
        ``out[b, :, p] == img_features[b, :, boundary[b, p, 0], boundary[b, p, 1]]``.
        ``samples`` is ``boundary.shape[0]``.
    """
    # Index tensors must be integral; use int64 on the feature maps' device so
    # the lookup does not depend on the caller's dtype or device choice.
    index = boundary.to(device=img_features.device, dtype=torch.long)
    rows, cols = index[:, :, 0], index[:, :, 1]

    # Column of sample ids; broadcasts against the (samples, points) indices.
    sample_ids = torch.arange(
        index.shape[0], device=img_features.device
    ).unsqueeze(1)

    # Channels-last view so the three index tensors address adjacent leading
    # dimensions and the gathered result is (samples, points, channels).
    channels_last = img_features.permute(0, 2, 3, 1)
    gathered = channels_last[sample_ids, rows, cols]

    # Back to the original (samples, channels, points) layout.
    return gathered.permute(0, 2, 1).contiguous()

In [9]:
# Collect the per-point feature vectors for the current and previous frame,
# then swap the last two axes so every boundary point becomes one row:
# (samples, channels, points) -> (samples, points, channels).
curr_bou_features = get_bou_features(curr_img_features, tmp_boundary).transpose(1, 2)
pre_bou_features = get_bou_features(pre_img_features, tmp_boundary).transpose(1, 2)
curr_bou_features.shape, pre_bou_features.shape

(torch.Size([3, 4, 2048]), torch.Size([3, 4, 2048]))

In [10]:
# One token per boundary point: current-frame features followed by
# previous-frame features along the last (feature) axis.
tokens = torch.cat((curr_bou_features, pre_bou_features), dim=-1)
tokens.shape

torch.Size([3, 4, 4096])

In [41]:
# Re-check the boundary tensor's shape before its coordinates are concatenated
# onto the tokens in the next cell.
tmp_boundary.shape

torch.Size([3, 4, 2])

In [12]:
# Append each point's two coordinates (cast to float32) to its token.
# NOTE(review): tmp_boundary must be on the same device as tokens here; a
# later cell rebinds tmp_boundary to "cuda", so re-create it before re-running.
boundary_coords = tmp_boundary.to(torch.float32)
final_tokens = torch.cat((tokens, boundary_coords), dim=-1)
final_tokens.shape

torch.Size([3, 4, 4098])

In [13]:
# Inspect the token of sample 2, point 2; its last two entries are that
# point's coordinates appended from tmp_boundary.
final_tokens[2, 2, :]

tensor([ 0.,  0.,  0.,  ...,  0., 13., 14.], grad_fn=<SliceBackward0>)

In [14]:
# Read the token width from final_tokens instead of hard-coding 4098, so the
# encoder stays in sync if the feature or coordinate width changes.
# nhead must divide d_model (4098 = 6 * 683 in this notebook).
d_model = final_tokens.shape[-1]
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=6, batch_first=True)
# Stack six identical encoder layers.
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

In [15]:
# Smoke-test the encoder: with batch_first=True the input is read as
# (batch, sequence, feature), and the output keeps the input's shape.
transformer_encoder(final_tokens).shape

torch.Size([3, 4, 4098])

In [16]:
from model import Model

In [19]:
# Use the GPU when one is visible and fall back to the CPU otherwise, so this
# cell also runs on CPU-only machines instead of failing on a hard-coded "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model comes from the project's model.py.
# NOTE(review): the argument 4 presumably is the number of boundary points per
# sample; confirm against model.py.
model = Model(4).to(device)
pre_frame = torch.randn(3, 3, 224, 224).to(device)
curr_frame = torch.randn(3, 3, 224, 224).to(device)
# NOTE(review): this rebinds tmp_boundary onto `device`; earlier cells that
# combine it with CPU tensors need tmp_boundary re-created before a re-run.
tmp_boundary = tmp_boundary.to(device)

In [40]:
# Forward pass through the project model with both frames and the boundary
# points; only the output shape is displayed.
model(pre_frame, curr_frame, tmp_boundary).shape

torch.Size([3, 8])