In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
import numpy as np
from model import PositionalEncoding
from model import IterWholeFirst
from einops import rearrange
import torchvision.transforms as T
from PIL import Image
from featup.util import norm, unnorm
from featup.plotting import plot_feats

In [2]:
# List this node's GPUs and their current memory use, to choose a free one below.
!nvidia-smi

Fri Apr  5 23:13:59 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  |   00000000:15:00.0 Off |                    0 |
| N/A   32C    P0             57W /  300W |    7641MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-32GB           On  |   00

In [3]:
import os

# `!export ...` runs in a throw-away subshell and never reaches this kernel, so
# set the variable in the kernel's own environment instead. CUDA reads it when
# it initialises, so this must run before the first CUDA call (the `.to(device)`
# calls below); restart the kernel if CUDA has already been initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [4]:
def get_bou_features(
    img_features: torch.Tensor, boundary: torch.Tensor
) -> torch.Tensor:
    """Sample each image's feature vectors at its boundary point locations.

    Args:
        img_features: 4-D feature maps laid out as (b, c, h, w).
        boundary: Integer point indices of shape (b, n, 2). ``boundary[..., 0]``
            indexes dim 2 and ``boundary[..., 1]`` indexes dim 3 of
            ``img_features``. Indices must lie inside the feature map, or
            indexing raises.

    Returns:
        Tensor of shape (b, c, n). Column ``k`` of batch item ``i`` is the
        feature vector of image ``i`` at point ``boundary[i, k]``. An empty
        batch (b == 0) yields an empty (0, c, n) tensor.
    """
    # Gather the whole batch in one advanced-indexing call instead of growing
    # the result with torch.cat inside a loop, which re-copies it every step.
    batch_idx = torch.arange(boundary.shape[0], device=boundary.device)[:, None]
    # The three index tensors broadcast to (b, n). Because they are separated by
    # the channel slice, the indexed dims come first, giving (b, n, c).
    sampled = img_features[batch_idx, :, boundary[..., 0], boundary[..., 1]]
    return sampled.permute(0, 2, 1)

In [5]:
# Exploration settings shared by the cells below.
boundary_num = 80  # boundary points per sample
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# eval(): the exploratory forward passes below feed random tensors. In train
# mode the BatchNorm layers of the pretrained ResNet trunk would overwrite
# their running statistics with them.
backbone = torch.hub.load(
    'pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True
).to(device).eval()

Using cache found in /zhome/32/f/202284/.cache/torch/hub/pytorch_vision_v0.10.0


In [6]:
# Random stand-in inputs for shape checks. Seed first so re-runs (and any
# module initialisation in later cells) are reproducible.
torch.manual_seed(0)
BATCH_SIZE = 3
IMG_SIZE = 224
first_img = torch.rand(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE, device=device)
previous_img = torch.rand(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE, device=device)
current_img = torch.rand(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE, device=device)
# Integer point coordinates in [0, IMG_SIZE), presumably image pixels — TODO confirm.
pre_boundary = torch.randint(0, IMG_SIZE, (BATCH_SIZE, boundary_num, 2), device=device)
first_boundary = torch.randint(0, IMG_SIZE, (BATCH_SIZE, boundary_num, 2), device=device)
pre_boundary.shape

torch.Size([3, 80, 2])

In [7]:
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50

In [8]:
from model import BaseDLV3

In [9]:
# Instantiate the full model under test and move it to the selected device.
model = BaseDLV3()
model = model.to(device)

Using cache found in /zhome/32/f/202284/.cache/torch/hub/pytorch_vision_v0.10.0


In [10]:
# Smoke-test forward pass on the random inputs.
# eval(): BaseDLV3 appears to load the pretrained hub DeepLabV3 (see the cache
# message when it was built). In train mode any BatchNorm layers would update
# their running statistics from these random tensors.
# no_grad(): nothing is back-propagated here, so skip recording the graph.
model.eval()
with torch.no_grad():
    results = model(first_img, first_boundary, previous_img, current_img, pre_boundary)

In [11]:
# Size of the last element returned by the model.
results[-1].size()

torch.Size([3, 80, 2])

In [7]:
# Which class torch.hub returned.
backbone.__class__

torchvision.models.segmentation.deeplabv3.DeepLabV3

In [11]:
# Replace the segmentation head with a pass-through, so backbone(x)['out']
# carries the (upsampled) trunk features instead of class scores.
setattr(backbone, "classifier", nn.Identity())

In [11]:
# With the head removed, 'out' is the feature map resized to the input resolution.
backbone(first_img)["out"].size()

torch.Size([3, 2048, 224, 224])

In [14]:
# Feature map taken straight from the ResNet trunk, before DeepLabV3 upsamples it.
raw_feat = backbone.backbone(first_img)["out"]
raw_feat.size()

torch.Size([3, 2048, 28, 28])

In [15]:
from model import get_img_tokens

In [16]:
# One token per spatial location: (batch, h*w, channels).
get_img_tokens(raw_feat).size()

torch.Size([3, 784, 2048])

In [41]:
def get_img_tokens(img_features: torch.Tensor) -> torch.Tensor:
    """Flatten a (b, c, h, w) feature map into a (b, h*w, c) token sequence.

    NOTE(review): this redefinition shadows the ``get_img_tokens`` imported from
    ``model`` in an earlier cell; keep only one of them.
    """
    pattern = "b c h w -> b (h w) c"
    return rearrange(img_features, pattern)
# NOTE(review): `first_img_raw_feature` is not defined in any visible cell. The
# recorded output (196 tokens x 1024 channels) came from a feature map produced
# elsewhere, not from `raw_feat` (2048 channels, 28x28). Restart & Run All fails here.
get_img_tokens(first_img_raw_feature).shape

torch.Size([3, 196, 1024])

In [67]:
# NOTE(review): `cur_img_raw_feature` is not defined in any visible cell, so this
# cell depends on hidden kernel state and fails on a fresh kernel.
# first_img_tokens = get_img_tokens(first_img_raw_feature)
# pre_img_tokens = get_img_tokens(pre_img_raw_feature)
cur_img_tokens = get_img_tokens(cur_img_raw_feature)
cur_img_tokens.shape

torch.Size([3, 196, 1024])

In [63]:
# Token width: 1024 feature channels + 2 point coordinates (the same
# concatenation used to build boundary_tokens).
# NOTE(review): the DeepLabV3 trunk above yields 2048 channels; 1024 matches the
# features these cells were actually run with — confirm which backbone is intended.
d_token = 1024 + 2
positional_encoding = PositionalEncoding(d_token).to(device)
layernorm = nn.LayerNorm(d_token).to(device)
# Single-head, single-layer encoder over batch-first (batch, seq, feature) input.
encoder_layer = nn.TransformerEncoderLayer(d_model=d_token, nhead=1, batch_first=True).to(device)
tranformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1).to(device)



In [44]:
# Project 1024-channel image tokens up to the transformer width.
img_token_encoder = nn.Linear(in_features=1024, out_features=d_token).to(device)

In [68]:
# Encode the current image tokens into the transformer memory:
# project -> LayerNorm -> positional encoding -> encoder.
# The embedded tokens get their own name, so re-running this cell does not feed
# already-projected (d_token-wide) tokens back into img_token_encoder.
cur_img_embedded = positional_encoding(layernorm(img_token_encoder(cur_img_tokens)))
memory = tranformer_encoder(cur_img_embedded)
memory.shape

torch.Size([3, 196, 1026])

In [46]:
# Sample first-frame and previous-frame features at the boundary points, laid
# out as (batch, points, channels) query tokens.
# NOTE(review): first_img_feature, pre_img_feature and boundary are not defined
# in any visible cell. Both queries also use the same `boundary` even though
# first_boundary / pre_boundary exist — confirm which points were intended.
first_query = get_bou_features(first_img_feature, boundary).permute(0, 2, 1)
pre_query = get_bou_features(pre_img_feature, boundary).permute(0, 2, 1)
first_query.shape


torch.Size([3, 80, 1024])

In [49]:
# Project 1024-channel query features up to the transformer width.
query_encoder = nn.Linear(in_features=1024, out_features=d_token).to(device)

In [57]:
# Embed previous-frame queries: project -> LayerNorm -> positional encoding.
pre_query_tokens = positional_encoding(layernorm(query_encoder(pre_query)))
pre_query_tokens.shape

torch.Size([3, 80, 1026])

In [58]:
# Embed first-frame queries with the same projection, norm and positional encoding.
first_query_tokens = positional_encoding(
    layernorm(
        query_encoder(first_query)
    )
)
first_query_tokens.shape

torch.Size([3, 80, 1026])

In [59]:
# Current-frame features at the boundary points, as (batch, points, channels).
# NOTE(review): cur_img_feature and boundary are not defined in any visible cell.
boundary_features = get_bou_features(cur_img_feature, boundary).permute(0, 2, 1)
boundary_features.shape

torch.Size([3, 80, 1024])

In [60]:
# Append each point's 2 coordinates to its feature vector (1024 + 2 = d_token).
boundary_tokens = torch.cat((boundary_features, boundary), dim=2)
boundary_tokens.shape

torch.Size([3, 80, 1026])

In [61]:
# LayerNorm + positional encoding over [features | coordinates] per point.
# Built from boundary_features/boundary rather than by overwriting
# boundary_tokens in place, so re-running this cell does not normalize and
# position-encode the tokens a second time.
boundary_tokens = positional_encoding(
    layernorm(torch.cat([boundary_features, boundary], dim=2))
)
boundary_tokens.shape

torch.Size([3, 80, 1026])

In [62]:
# Single-head, single-layer decoder: the input tokens attend to `memory`.
decoder_layer = nn.TransformerDecoderLayer(d_model=d_token, nhead=1, batch_first=True).to(device)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=1).to(device)

In [69]:
# Decoder input, concatenated along the sequence axis in this order:
# [previous-frame queries | current boundary tokens | first-frame queries].
# The boundary_num-sized slices taken from out_tokens later depend on this order.
token_groups = (pre_query_tokens, boundary_tokens, first_query_tokens)
input_tokens = torch.cat(token_groups, dim=1)
del token_groups
input_tokens.shape

torch.Size([3, 240, 1026])

In [71]:
# Decode: target sequence = input_tokens, cross-attending to the encoder memory.
out_tokens = transformer_decoder(tgt=input_tokens, memory=memory)
out_tokens.shape

torch.Size([3, 240, 1026])

In [72]:
# Output heads: a 2-value offset per point, and a d_token-wide update per query.
xy_offset_encoder = nn.Linear(in_features=d_token, out_features=2).to(device)
q_offset_encoder = nn.Linear(in_features=d_token, out_features=d_token).to(device)

In [73]:
# Query updates from the first boundary_num output tokens (the previous-frame query slots).
q_offset = q_offset_encoder(out_tokens[:, :boundary_num])
q_offset.shape

torch.Size([3, 80, 1026])

In [74]:
# xy offsets from the middle boundary_num output tokens (the boundary-token slots).
start, stop = boundary_num, 2 * boundary_num
xy_offset = xy_offset_encoder(out_tokens[:, start:stop, :])
del start, stop
xy_offset.shape

torch.Size([3, 80, 2])

In [75]:
# Offsets added element-wise onto the boundary coordinates.
(boundary + xy_offset).size()

torch.Size([3, 80, 2])