In [24]:
import copy
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from einops import rearrange
from featup.plotting import plot_feats
from featup.util import norm, unnorm
from PIL import Image
from torchsummary import summary
from torchvision.models import resnet50, ResNet50_Weights

from model import PositionalEncoding, get_bou_features, find_best_shift
from model import IterWholeFirst

In [2]:
# Dummy inputs for shape testing: a batch of RGB frames plus integer boundary
# point coordinates. Seeded so every Restart & Run All sees the same data.
torch.manual_seed(0)
boundary_num = 80
batch_size = 3
img_size = 224
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
first_img = torch.rand(batch_size, 3, img_size, img_size).to(device)
previous_img = torch.rand(batch_size, 3, img_size, img_size).to(device)
current_img = torch.rand(batch_size, 3, img_size, img_size).to(device)
# (B, boundary_num, 2) integer coordinates in [0, img_size).
pre_boundary = torch.randint(0, img_size, (batch_size, boundary_num, 2)).to(device)
first_boundary = torch.randint(0, img_size, (batch_size, boundary_num, 2)).to(device)
pre_boundary.shape

torch.Size([3, 80, 2])

In [3]:
# Load the FeatUp "dino16" backbone from torch.hub (served from the local hub
# cache when available, see the log below). On a 224x224 RGB batch it returns
# 384-channel features at the full 224x224 resolution.
# NOTE(review): use_norm=True presumably enables FeatUp's feature
# normalisation -- confirm against the FeatUp hubconf.
featup_backbone = torch.hub.load(
    "mhamilton723/FeatUp",
    "dino16",
    use_norm=True,
).to(device)

Using cache found in /zhome/32/f/202284/.cache/torch/hub/mhamilton723_FeatUp_main
Using cache found in /zhome/32/f/202284/.cache/torch/hub/facebookresearch_dino_main


In [18]:
# The ViT patch-embedding conv: 3 input channels -> 384, 16x16 kernel and stride.
patch_embed_proj = featup_backbone.model[0].model.patch_embed.proj
patch_embed_proj

Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))

In [13]:
# Dummy single-channel first-frame mask (random values, shape testing only).
first_mask = torch.rand(3, 1, 224, 224).to(device)
first_mask.shape

torch.Size([3, 1, 224, 224])

In [14]:
# Stack RGB and mask along the channel axis -> 4-channel input.
first_con = torch.cat([first_img, first_mask], dim=1)
first_con.shape

torch.Size([3, 4, 224, 224])

In [16]:
# Sanity check: the 3-channel patch embedding gives a 14x14 token grid (224 / 16).
patch_embed_proj(first_img).shape

torch.Size([3, 384, 14, 14])

In [25]:
def load_weights_sequential(target, source_state, extra_chan=1):
    """Load ``source_state`` into ``target``, widening conv input channels.

    Every entry of ``target.state_dict()`` that also appears in
    ``source_state`` is copied, except ``num_batches_tracked`` counters. When
    the target tensor is a 4-D conv weight with exactly ``extra_chan`` more
    input channels (dim 1) than the source -- e.g. a pretrained RGB patch
    embedding loaded into a conv that also takes a mask channel -- the source
    weight is kept for the original channels and the extra channels are
    appended with an orthogonal initialisation. Keys missing from
    ``source_state`` keep their current values (``strict=False``).

    Args:
        target: module whose weights are overwritten in place.
        source_state: mapping of parameter/buffer name -> tensor, e.g. another
            module's ``state_dict()``.
        extra_chan: number of input channels ``target`` has beyond the source.

    Raises:
        ValueError: if a shared key has a shape mismatch other than
            ``extra_chan`` additional input channels on a 4-D weight.
    """
    new_dict = OrderedDict()

    for key, target_v in target.state_dict().items():
        if "num_batches_tracked" in key or key not in source_state:
            continue
        source_v = source_state[key]

        if target_v.shape != source_v.shape:
            widenable = (
                target_v.dim() == 4
                and source_v.dim() == 4
                and target_v.shape[0] == source_v.shape[0]
                and target_v.shape[2:] == source_v.shape[2:]
                and target_v.shape[1] == source_v.shape[1] + extra_chan
            )
            if not widenable:
                raise ValueError(
                    f"cannot load {key!r}: source shape {tuple(source_v.shape)} "
                    f"is not target shape {tuple(target_v.shape)} minus "
                    f"{extra_chan} input channel(s)"
                )
            out_c, _, k_h, k_w = target_v.shape
            # The new input channels get an orthogonal init (it overwrites the
            # zeros entirely). Build them in float32 -- QR for half precision
            # may be unsupported -- then match the source dtype for the concat.
            pads = torch.zeros((out_c, extra_chan, k_h, k_w), device=source_v.device)
            nn.init.orthogonal_(pads)
            source_v = torch.cat([source_v, pads.to(source_v.dtype)], dim=1)

        new_dict[key] = source_v

    target.load_state_dict(new_dict, strict=False)

In [28]:
# 4-channel (RGB + mask) twin of the patch-embedding conv, same 16x16 patch
# geometry. The RGB weights are copied over; the mask channel is freshly initialised.
new_patch_embed_proj = nn.Conv2d(
    in_channels=4,
    out_channels=384,
    kernel_size=16,
    stride=16,
).to(device)
load_weights_sequential(new_patch_embed_proj, patch_embed_proj.state_dict())
new_patch_embed_proj(first_con).shape

torch.Size([3, 384, 14, 14])

In [29]:
# Full FeatUp forward on RGB: 384-channel features at the 224x224 input size.
featup_backbone(first_img).shape

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([3, 384, 224, 224])

In [31]:
# Build the 4-channel (RGB + mask) variant on a *copy* of the backbone, so
# `featup_backbone` keeps its 3-channel patch embedding for the later cells.
featup_backbone_rgbm = copy.deepcopy(featup_backbone)
featup_backbone_rgbm.model[0].model.patch_embed.proj = new_patch_embed_proj
# Calling the full backbone on the 4-channel tensor fails in the upsampler
# ("weight of size [32, 3, 1, 1] ... expected ... 3 channels, but got 4"),
# because upsampling is guided by the input image. So run the ViT on
# RGB + mask and guide the upsampling with the RGB image only.
# NOTE(review): relies on FeatUp's backbone exposing `.model` and `.upsampler`
# (forward == upsampler(model(img), img)) -- confirm against the FeatUp hubconf.
rgbm_low_res_feats = featup_backbone_rgbm.model(first_con)
featup_backbone_rgbm.upsampler(rgbm_low_res_feats, first_img).shape

In [8]:
# Inspect the wrapped VisionTransformer (patch embed, 12 blocks, final norm).
featup_backbone.model[0].model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Identity()
)

In [6]:
# Backbone features are fixed inputs for the shape experiments below. Skip
# building an autograd graph through the (unfrozen) backbone for these
# 224x224, 384-channel outputs.
with torch.no_grad():
    pre_img_feats = featup_backbone(previous_img)
    cur_img_feats = featup_backbone(current_img)
    fir_img_feats = featup_backbone(first_img)
pre_img_feats.shape, cur_img_feats.shape, fir_img_feats.shape

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


(torch.Size([3, 384, 224, 224]),
 torch.Size([3, 384, 224, 224]),
 torch.Size([3, 384, 224, 224]))

In [14]:
def get_neighbor_features(
    img_features: torch.Tensor,
    boundary: torch.Tensor,
    scale_level: int,
):
    """Gather 3x3-neighbourhood features around each boundary point.

    The feature map is average-pooled by ``2**scale_level``, and the boundary
    coordinates are integer-divided by the same factor. The map is then
    zero-padded by one cell on every side and sampled with
    ``get_bou_features`` at the nine offsets {-1, 0, 1}^2 around each point.
    The nine samples are concatenated along dim 2.

    Args:
        img_features: feature map, presumably (B, C, H, W) as produced by the
            backbone in this notebook -- TODO confirm for other callers.
        boundary: boundary coordinates in the un-pooled feature-map frame, in
            the same convention ``get_bou_features`` uses; last dim is 2.
        scale_level: pooling exponent; 0 means no pooling.

    Returns:
        ``torch.cat`` of the nine ``get_bou_features`` results along dim 2.
        In this notebook that is (B, C, 9 * N).
    """
    if scale_level > 0:
        factor = 2**scale_level
        img_features = F.avg_pool2d(img_features, factor, factor)
        boundary = boundary // factor
    # Padding moves every cell by +1 along both spatial axes, so shift the
    # coordinates by the same amount. That keeps offset (0, 0) on the
    # boundary point itself, and offsets of -1 no longer index position -1.
    img_features = F.pad(img_features, (1, 1, 1, 1), "constant", 0)
    padded_boundary = boundary + 1
    offsets = [(dx, dy) for dx in range(-1, 2) for dy in range(-1, 2)]
    neighbor_features = [
        get_bou_features(img_features, padded_boundary + boundary.new_tensor(offset))
        for offset in offsets
    ]
    return torch.cat(neighbor_features, dim=2)

In [25]:
# One neighbour-feature tensor per pyramid level 0..3, transposed so each
# sampled location is a row: (B, 9 * boundary_num, C).
neigh_feats = [
    get_neighbor_features(cur_img_feats, pre_boundary, level).permute(0, 2, 1)
    for level in range(4)
]

In [26]:
# Expect (B, 9 * boundary_num, C) at every scale level.
for neigh_feat in neigh_feats:
    print(neigh_feat.shape)

torch.Size([3, 720, 384])
torch.Size([3, 720, 384])
torch.Size([3, 720, 384])
torch.Size([3, 720, 384])


In [27]:
# Project 384-dim neighbour features to the 384 + 2 token width used by the
# boundary tokens (feature + xy).
neigh_token_fc = nn.Linear(384, 384 + 2).to(device)

In [28]:
# Project every scale's neighbour features to 386-dim tokens.
neigh_tokens = [neigh_token_fc(neigh_feats[level]) for level in range(4)]
for tokens in neigh_tokens:
    print(tokens.shape)

torch.Size([3, 720, 386])
torch.Size([3, 720, 386])
torch.Size([3, 720, 386])
torch.Size([3, 720, 386])


In [29]:
# One LayerNorm and one positional encoding, shared by the neighbour tokens and
# the boundary query tokens below.
layer_norm = nn.LayerNorm(384 + 2).to(device)
position_enc = PositionalEncoding(384 + 2).to(device)

In [30]:
# Normalise each scale's tokens, then add positional encodings (rebinds
# neigh_tokens; re-running this cell applies the transform again).
neigh_tokens = [
    position_enc(layer_norm(neigh_tokens[level]))
    for level in range(4)
]
for tokens in neigh_tokens:
    print(tokens.shape)

torch.Size([3, 720, 386])
torch.Size([3, 720, 386])
torch.Size([3, 720, 386])
torch.Size([3, 720, 386])


In [31]:
# Single-layer, single-head encoder (d_model = 384 + 2, batch-first) applied
# to each scale's neighbour tokens.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=384 + 2,
    nhead=1,
    batch_first=True,
)
transformer_encoder = nn.TransformerEncoder(
    encoder_layer,
    num_layers=1,
).to(device)



In [33]:
# Encode each scale's tokens into a memory sequence for the decoder.
memorys = [transformer_encoder(neigh_tokens[level]) for level in range(4)]
for memory in memorys:
    print(memory.shape)

torch.Size([3, 720, 386])
torch.Size([3, 720, 386])
torch.Size([3, 720, 386])
torch.Size([3, 720, 386])


In [34]:
# Sample backbone features at boundary points -> (B, C, boundary_num) each.
# The current frame is sampled at the *previous* boundary.
pre_bou_feats, cur_bou_feats, fir_bou_feats = (
    get_bou_features(feats, bou)
    for feats, bou in (
        (pre_img_feats, pre_boundary),
        (cur_img_feats, pre_boundary),
        (fir_img_feats, first_boundary),
    )
)
pre_bou_feats.shape, cur_bou_feats.shape, fir_bou_feats.shape

(torch.Size([3, 384, 80]), torch.Size([3, 384, 80]), torch.Size([3, 384, 80]))

In [35]:
# (B, C, N) -> (B, N, C) so each boundary point becomes one token row.
fir_bou_feats, pre_bou_feats, cur_bou_feats = (
    feats.permute(0, 2, 1)
    for feats in (fir_bou_feats, pre_bou_feats, cur_bou_feats)
)
fir_bou_feats.shape, pre_bou_feats.shape, cur_bou_feats.shape

(torch.Size([3, 80, 384]), torch.Size([3, 80, 384]), torch.Size([3, 80, 384]))

In [39]:
# Token = sampled feature (384) ++ boundary coordinates (2) -> 386 dims.
# NOTE(review): coordinates are raw pixel indices (0-223 for this dummy data),
# concatenated without rescaling; consider normalising them.
fir_tokens = torch.cat([fir_bou_feats, first_boundary.float()], dim=2)
pre_tokens = torch.cat([pre_bou_feats, pre_boundary.float()], dim=2)
cur_tokens = torch.cat([cur_bou_feats, pre_boundary.float()], dim=2)
fir_tokens.shape, pre_tokens.shape, cur_tokens.shape

(torch.Size([3, 80, 386]), torch.Size([3, 80, 386]), torch.Size([3, 80, 386]))

In [38]:
# Single-layer, single-head decoder: boundary query tokens attend to the
# per-scale encoder memories.
decoder_layer = nn.TransformerDecoderLayer(
    d_model=384 + 2,
    nhead=1,
    batch_first=True,
)
transformer_decoder = nn.TransformerDecoder(
    decoder_layer,
    num_layers=1,
).to(device)

In [41]:
# Query sequence [previous | current | first] -> (B, 3 * boundary_num, 386),
# normalised and positionally encoded.
input_tokens = torch.cat([pre_tokens, cur_tokens, fir_tokens], dim=1)
input_tokens = layer_norm(input_tokens)
input_tokens = position_enc(input_tokens)
input_tokens.shape

torch.Size([3, 240, 386])

In [42]:
# Coarse-to-fine: refine the queries against scale 3 (most pooled) down to
# scale 0, reusing the same decoder weights at every scale.
for scale_level in reversed(range(4)):
    input_tokens = transformer_decoder(
        input_tokens,
        memorys[scale_level],
    )
    print(input_tokens.shape)

torch.Size([3, 240, 386])
torch.Size([3, 240, 386])
torch.Size([3, 240, 386])
torch.Size([3, 240, 386])


In [43]:
# Keep the middle block of the sequence: the current-frame tokens.
out_tokens = input_tokens[
    :,
    boundary_num : 2 * boundary_num,
    :,
]
out_tokens.shape

torch.Size([3, 80, 386])

In [44]:
# Regress a 2-D offset per current-frame boundary point -> (B, boundary_num, 2).
xy_offset_fc = nn.Linear(384 + 2, 2).to(device)
xy_offset = xy_offset_fc(out_tokens)
xy_offset.shape

torch.Size([3, 80, 2])

In [46]:
class FeatupNei(nn.Module):
    """Predict the current-frame boundary from three frames and two boundaries.

    Pipeline:

    1. Roll (a copy of) the first-frame boundary by the per-sample shift from
       ``find_best_shift`` -- presumably the cyclic shift that best aligns it
       with the previous boundary.
    2. Extract features for the first, previous and current frames with a
       frozen FeatUp ``dino16`` backbone.
    3. For each of ``scale_num`` pyramid levels, gather 3x3-neighbourhood
       features of the current frame around the previous boundary, project
       them to tokens and encode them into a memory sequence.
    4. Build boundary tokens (sampled feature ++ coordinates) for the
       previous, current and first frames. Decode them against the memories
       from the coarsest to the finest level, and regress an offset for each
       current-frame token.
    5. Return ``pre_bou + offset``.
    """

    def __init__(self):
        super().__init__()
        # Default boundary-point count. forward() slices by the actual count
        # of its inputs; the attribute is kept for callers that read it.
        self.bou_num = 80
        # Token width: 384-dim backbone features + 2 boundary coordinates.
        d_token = 384 + 2
        self.scale_num = 4

        # Submodules are built on the CPU. Move the whole model with
        # `FeatupNei().to(device)` rather than depending on a global `device`.
        self.backbone = torch.hub.load(
            "mhamilton723/FeatUp",
            "dino16",
            use_norm=True,
        )
        # Freeze the backbone: only the heads below are trainable.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Attribute name (sic) kept so existing state_dict keys still load.
        self.neighor_fc = nn.Linear(384, d_token)
        # Shared by the neighbour tokens and the boundary query tokens.
        self.layer_norm = nn.LayerNorm(d_token)
        self.position_enc = PositionalEncoding(d_token)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_token,
            nhead=1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_token,
            nhead=1,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)
        self.xy_offset_fc = nn.Linear(d_token, 2)

    def forward(
        self,
        fir_img: torch.Tensor,
        fir_bou: torch.Tensor,
        pre_img: torch.Tensor,
        cur_img: torch.Tensor,
        pre_bou: torch.Tensor,
    ):
        """Predict the current-frame boundary.

        Args:
            fir_img, pre_img, cur_img: first / previous / current frames,
                presumably (B, 3, H, W) RGB as in this notebook -- TODO confirm.
            fir_bou: first-frame boundary, (B, N_fir, 2). Not modified.
            pre_bou: previous-frame boundary, (B, N, 2).

        Returns:
            One-element list holding ``pre_bou.float() + offset`` of shape
            (B, N, 2).
        """
        # Roll a copy: writing into the caller's tensor would make the shift
        # accumulate across repeated calls with the same boundary.
        best_shift = find_best_shift(pre_bou, fir_bou)
        fir_bou = fir_bou.clone()
        for i in range(len(best_shift)):
            fir_bou[i] = fir_bou[i].roll(best_shift[i], 0)

        fir_img_feats = self.backbone(fir_img)
        pre_img_feats = self.backbone(pre_img)
        cur_img_feats = self.backbone(cur_img)

        # Per-scale neighbour tokens from the current frame around the
        # previous boundary: (B, C, 9N) -> (B, 9N, C) -> d_token, normalised
        # and positionally encoded. Then one encoder memory per scale.
        neigh_tokens = [
            self.position_enc(
                self.layer_norm(
                    self.neighor_fc(
                        get_neighbor_features(
                            cur_img_feats, pre_bou, scale_level
                        ).permute(0, 2, 1)
                    )
                )
            )
            for scale_level in range(self.scale_num)
        ]
        memorys = [self.transformer_encoder(tokens) for tokens in neigh_tokens]

        # Boundary tokens: features sampled at the boundary, (B, C, N) ->
        # (B, N, C), concatenated with the coordinates. The current frame is
        # sampled at the previous boundary.
        pre_bou_feats = get_bou_features(pre_img_feats, pre_bou).permute(0, 2, 1)
        cur_bou_feats = get_bou_features(cur_img_feats, pre_bou).permute(0, 2, 1)
        fir_bou_feats = get_bou_features(fir_img_feats, fir_bou).permute(0, 2, 1)
        pre_tokens = torch.cat([pre_bou_feats, pre_bou.float()], dim=2)
        cur_tokens = torch.cat([cur_bou_feats, pre_bou.float()], dim=2)
        fir_tokens = torch.cat([fir_bou_feats, fir_bou.float()], dim=2)

        # Query sequence [previous | current | first], decoded coarse-to-fine.
        input_tokens = torch.cat([pre_tokens, cur_tokens, fir_tokens], dim=1)
        input_tokens = self.position_enc(self.layer_norm(input_tokens))
        for scale_level in reversed(range(self.scale_num)):
            input_tokens = self.transformer_decoder(input_tokens, memorys[scale_level])

        # The current-frame tokens are the middle block. Slice with the real
        # point count rather than the hard-coded default so any N works.
        num_points = pre_bou.shape[1]
        out_tokens = input_tokens[:, num_points : 2 * num_points, :]
        xy_offset = self.xy_offset_fc(out_tokens)
        result = xy_offset + pre_bou.float()
        return [result]

In [47]:
# Build the full model (loads the backbone again via torch.hub) and move it to `device`.
model = FeatupNei().to(device)

Using cache found in /zhome/32/f/202284/.cache/torch/hub/mhamilton723_FeatUp_main
Using cache found in /zhome/32/f/202284/.cache/torch/hub/facebookresearch_dino_main


In [48]:
# Forward pass on the dummy batch. The model returns a one-element list, so
# [0] is the predicted boundary of shape (B, boundary_num, 2).
model(
    first_img,
    first_boundary,
    previous_img,
    current_img,
    pre_boundary,
)[0]

tensor([[[153.1673, 195.6584],
         [ 27.4937,  31.4670],
         [150.3264,  70.5704],
         [156.4951, 104.6779],
         [213.1922, 155.4496],
         [132.9963,  62.5916],
         [214.2008, 121.6199],
         [ 68.6694, 137.4595],
         [205.1222,  80.6518],
         [136.2520,  90.4591],
         [216.7162,  96.6581],
         [ 19.2738,  97.3381],
         [ 38.0802, 188.5018],
         [ 79.4439,  45.3750],
         [ 92.1500, 184.3403],
         [204.6061,  29.6343],
         [ 31.1597,  50.6149],
         [ 50.5443, 113.3948],
         [ 42.0186, 116.3720],
         [ 73.3154, 119.4523],
         [110.8636, 196.4116],
         [ 94.1512, 153.4003],
         [120.6401,  19.5699],
         [ 73.4373, 106.4883],
         [160.1704, 212.2492],
         [ 48.2539, 169.3328],
         [146.7191, 179.5117],
         [  9.8852, 199.5704],
         [ 40.9352, 210.2940],
         [ 16.4163,  60.3045],
         [ 37.9838, 184.9747],
         [ 43.0824,  54.5503],
        