In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
import numpy as np
from model import PositionalEncoding, get_bou_features, find_best_shift, get_img_tokens
from model import IterWholeFirst
from einops import rearrange
import torchvision.transforms as T
from PIL import Image
from featup.util import norm, unnorm
from featup.plotting import plot_feats
from collections import OrderedDict
import time

In [2]:
from DETR_model import get_res4

In [3]:
# Build the ResNet trunk from DETR_model; the printout of `res4` below shows its
# stem takes 4 input channels (Conv2d(4, 64, ...)).
res4 = get_res4()



In [4]:
# Smoke test: a batch of three random 4-channel 224x224 inputs through the trunk.
# NOTE(review): `.cuda()` is hard-coded because `device` is only defined in a
# later cell; this cell fails on a CPU-only machine.
tmp = torch.rand(3, 4, 224, 224).cuda()
res4(tmp).shape

torch.Size([3, 1024, 14, 14])

In [5]:
# Display the module structure of the trunk.
res4

Sequential(
  (0): Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [2]:
# Random stand-in inputs for one batch of 3 samples.
# Images are (3, 3, 224, 224) and masks (3, 1, 224, 224); both are uniform floats
# in [0, 1), so the masks here are NOT binary.
# NOTE(review): no random seed is set, so every run uses different values.
boundary_num = 80  # number of boundary points per sample
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# First frame: image, mask, and their channel-wise concatenation (4 channels).
fir_img = torch.rand(3, 3, 224, 224).to(device)
fir_mask = torch.rand(3, 1, 224, 224).to(device)
fir_con = torch.cat((fir_img, fir_mask), dim=1)
# Previous frame: same layout as the first frame.
pre_img = torch.rand(3, 3, 224, 224).to(device)
pre_mask = torch.rand(3, 1, 224, 224).to(device)
pre_con = torch.cat((pre_img, pre_mask), dim=1)
# Current frame.
cur_img = torch.rand(3, 3, 224, 224).to(device)
cur_mask = torch.rand(3, 1, 224, 224).to(device)
# Integer (x, y)-style pairs in [0, 224) for the previous and first frames.
pre_bou = torch.randint(0, 224, (3, boundary_num, 2)).to(device)
fir_bou = torch.randint(0, 224, (3, boundary_num, 2)).to(device)

In [5]:
# Apply each sample's mask to its own image.
# fir_mask is (B, 1, H, W), so it already broadcasts over the channel dim.
# Squeezing it to (B, H, W) would right-align against (B, 3, H, W) and multiply
# image b / channel c by mask[c] instead of mask[b]. The shape only looked right
# because B == 3 here.
fir_masked_img = fir_img * fir_mask
fir_masked_img.shape

torch.Size([3, 3, 224, 224])

In [4]:
# Load FeatUp's "dino16" model through torch.hub (the log below shows it was
# served from the local hub cache) and move it to `device`.
featup = torch.hub.load(
    "mhamilton723/FeatUp",
    "dino16",
    use_norm=True,
).to(device)

Using cache found in /zhome/32/f/202284/.cache/torch/hub/mhamilton723_FeatUp_main
Using cache found in /zhome/32/f/202284/.cache/torch/hub/facebookresearch_dino_main


In [27]:
def add_extra_channels(conv2d: nn.Conv2d, extra_chan=1):
    """
    Return a copy of ``conv2d`` that accepts ``extra_chan`` additional input channels.

    The original weights (and bias, if any) are copied unchanged into the first
    ``conv2d.in_channels`` input channels of the new layer. The weights of the
    appended channels are initialised with ``nn.init.orthogonal_``. Because the
    new channels come after the original ones, ``new(torch.cat([x, zeros], 1))``
    reproduces ``conv2d(x)`` up to floating-point rounding.

    Args:
        conv2d: layer to extend. Only ``groups == 1`` is supported.
        extra_chan: number of input channels to append (must be >= 0).

    Returns:
        A new ``nn.Conv2d`` on the same device and with the same dtype as
        ``conv2d.weight``. ``conv2d`` itself is not modified.

    Raises:
        ValueError: if ``extra_chan`` is negative or ``conv2d.groups != 1``.
    """
    if extra_chan < 0:
        raise ValueError(f"extra_chan must be >= 0, got {extra_chan}")
    if conv2d.groups != 1:
        # For grouped convs the weight's input dim is in_channels // groups, so
        # concatenating extra_chan slices along dim 1 cannot yield a valid layer.
        raise ValueError(
            f"add_extra_channels only supports groups == 1, got {conv2d.groups}"
        )
    weight = conv2d.weight
    new_conv2d = nn.Conv2d(
        conv2d.in_channels + extra_chan,
        conv2d.out_channels,
        conv2d.kernel_size,
        stride=conv2d.stride,
        padding=conv2d.padding,
        dilation=conv2d.dilation,
        groups=conv2d.groups,
        bias=conv2d.bias is not None,
        padding_mode=conv2d.padding_mode,
        device=weight.device,
        dtype=weight.dtype,
    )
    # Build the source state dict once instead of once per parameter.
    old_state = conv2d.state_dict()
    new_state = OrderedDict()
    for name, param in new_conv2d.state_dict().items():
        old_param = old_state[name]
        if old_param.shape != param.shape:
            # Only the weight differs: (out_ch, in_ch, k_h, k_w) gains input channels.
            out_ch, _, k_h, k_w = param.shape
            # Initialise in float32 (orthogonal_ relies on a QR decomposition),
            # then cast to the layer's dtype.
            pads = torch.zeros((out_ch, extra_chan, k_h, k_w), device=old_param.device)
            nn.init.orthogonal_(pads)
            old_param = torch.cat((old_param, pads.to(old_param.dtype)), dim=1)
        new_state[name] = old_param
    new_conv2d.load_state_dict(new_state)
    return new_conv2d

In [49]:
# Grab the backbone wrapped by the FeatUp model and show the patch-embedding
# conv that will later be widened to accept a mask channel.
dino16 = featup.model
dino16[0].model.patch_embed.proj

Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))

In [50]:
# Sanity check for add_extra_channels: with an all-zero 4th (mask) channel, the
# widened projection should match the original 3-channel projection, so
# `offset` should be ~0.
# NOTE(review): `offset` is computed but never displayed; inspect
# e.g. `offset.abs().max()` to actually see the check.
new_proj = add_extra_channels(dino16[0].model.patch_embed.proj, 1)
tmp_mask = torch.zeros((3, 1, 224, 224)).to(device)
offset = new_proj(
    torch.cat(
        [fir_img, tmp_mask],
        dim=1,
    )
) - dino16[0].model.patch_embed.proj(fir_img)
new_proj

Conv2d(4, 384, kernel_size=(16, 16), stride=(16, 16))

In [51]:
def get_dino4(device="cuda") -> nn.Module:
    """
    Load FeatUp's ``dino16`` backbone with a 4-channel patch embedding.

    The patch-embedding conv at ``[0].model.patch_embed.proj`` is replaced by
    the copy returned from ``add_extra_channels(..., 1)``, so the backbone can
    be fed ``torch.cat([img, mask], dim=1)``.

    Args:
        device: device the hub-loaded model is moved to before it is modified.

    Returns:
        The ``.model`` attribute of the hub-loaded FeatUp module, modified in place.
    """
    hub_model = torch.hub.load("mhamilton723/FeatUp", "dino16", use_norm=True)
    backbone = hub_model.to(device).model
    vit = backbone[0].model
    vit.patch_embed.proj = add_extra_channels(vit.patch_embed.proj, 1)
    return backbone
# Run the 4-channel backbone on the first frame's image+mask concatenation
# (the same tensor as `fir_con` from the inputs cell).
dino4 = get_dino4()
dino4(torch.cat([fir_img, fir_mask], dim=1)).shape

Using cache found in /zhome/32/f/202284/.cache/torch/hub/mhamilton723_FeatUp_main
Using cache found in /zhome/32/f/202284/.cache/torch/hub/facebookresearch_dino_main


torch.Size([3, 384, 14, 14])

In [58]:
hidden_dim = 384  # channel dim of the dino16 features (see the shapes printed above)

# Encode the first and previous frames (image + mask) with the 4-channel backbone.
fir_con_feats = dino4(torch.cat([fir_img, fir_mask], dim=1))
pre_con_feats = dino4(torch.cat([pre_img, pre_mask], dim=1))
# NOTE(review): get_img_tokens presumably flattens (B, C, h, w) maps into
# (B, h*w, C) tokens; the final (3, 392, 384) shape is consistent with
# 2 x 14 x 14 tokens per sample. Confirm in model.py.
fir_con_tokens = get_img_tokens(fir_con_feats)
pre_con_tokens = get_img_tokens(pre_con_feats)
# Concatenate both frames along the token dimension into one memory sequence.
mem_img_tokens = torch.cat((fir_con_tokens, pre_con_tokens), dim=1)

layernorm = nn.LayerNorm(hidden_dim).to(device)
mem_img_tokens = layernorm(mem_img_tokens)

# Positional encoding is applied to the concatenated sequence, so the two
# frames' tokens receive different positions.
pos_enc = PositionalEncoding(hidden_dim).to(device)
mem_img_tokens = pos_enc(mem_img_tokens)
mem_img_tokens.shape

torch.Size([3, 392, 384])

In [61]:
# Learned boundary queries. Move the tensor to `device` *before* wrapping it:
# `nn.Parameter(...).to(device)` returns a plain non-leaf tensor whenever a copy
# is made (e.g. CPU -> CUDA), so the result would no longer be a Parameter.
# `boundary_num` comes from the inputs cell; it is not redefined here.
query_param = nn.Parameter(torch.rand(boundary_num, hidden_dim).to(device))
batch_size = mem_img_tokens.shape[0]
# One copy of the queries per sample (expand shares storage, no copy).
queries = query_param.unsqueeze(0).expand(batch_size, -1, -1)
transformer1 = nn.Transformer(
    d_model=hidden_dim,
    nhead=1,
    num_encoder_layers=1,
    num_decoder_layers=1,
    batch_first=True,
).to(device)
# Memory tokens are the encoder input (src); the queries are the decoder input (tgt).
tmp = transformer1(mem_img_tokens, queries)
tmp.shape

torch.Size([3, 80, 384])

In [62]:
def get_raw_dino(device="cuda") -> nn.Module:
    """
    Load FeatUp's ``dino16`` backbone without modification.

    Unlike ``get_dino4``, the patch embedding keeps its original input channels.

    Args:
        device: device the hub-loaded model is moved to.

    Returns:
        The ``.model`` attribute of the hub-loaded FeatUp module.
    """
    return torch.hub.load(
        "mhamilton723/FeatUp", "dino16", use_norm=True
    ).to(device).model

In [66]:
# Encode the current frame (image only, no mask) with the unmodified backbone.
raw_dino = get_raw_dino()
cur_img_feats = raw_dino(cur_img)
cur_img_tokens = get_img_tokens(cur_img_feats)

Using cache found in /zhome/32/f/202284/.cache/torch/hub/mhamilton723_FeatUp_main
Using cache found in /zhome/32/f/202284/.cache/torch/hub/facebookresearch_dino_main


In [67]:
# Reuse the same `layernorm` and `pos_enc` modules as the memory tokens,
# mirroring what DinoDETR.forward does.
cur_img_tokens = layernorm(cur_img_tokens)
cur_img_tokens = pos_enc(cur_img_tokens)
cur_img_tokens.shape

torch.Size([3, 196, 384])

In [68]:
# Second stage: same architecture as transformer1, with separate weights.
transformer2 = nn.Transformer(
    d_model=hidden_dim,
    nhead=1,
    num_encoder_layers=1,
    num_decoder_layers=1,
    batch_first=True,
).to(device)

# Current-frame tokens are src; the first stage's output (`tmp`) is tgt.
output = transformer2(cur_img_tokens, tmp)
output.shape



torch.Size([3, 80, 384])

In [77]:
# Map each decoded query to two values squashed into (0, 1) by a sigmoid.
xy_fc = nn.Linear(hidden_dim, 2).to(device)
norm_xy = xy_fc(output).sigmoid()
norm_xy.max(), norm_xy.min()

(tensor(0.7364, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(0.1657, device='cuda:0', grad_fn=<MinBackward1>))

In [78]:
# Scale the normalised outputs to pixel units.
# NOTE(review): 224 is a magic number; presumably the input image side length.
xy = norm_xy * 224
xy.max(), xy.min()

(tensor(164.9437, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(37.1275, device='cuda:0', grad_fn=<MinBackward1>))

In [96]:
class DinoDETR(nn.Module):
    """
    DETR-style model that regresses ``boundary_num`` 2-D points for ``cur_img``.

    As implemented in ``forward``:
      1. ``dino4`` (DINO with a 4-channel patch embedding) encodes the
         image+mask pairs of the first and previous frames. Their tokens are
         concatenated, layer-normed and position-encoded into one memory sequence.
      2. ``transformer1`` runs the learned ``queries`` (as ``tgt``) against that
         memory (as ``src``).
      3. The frozen ``raw_dino`` encodes ``cur_img``. ``transformer2`` runs the
         step-2 output (as ``tgt``) against those tokens (as ``src``).
      4. ``xy_fc`` plus a sigmoid maps each query to two values in [0, 1], which
         are then multiplied by ``img_size``.

    Args:
        boundary_num: number of learned queries, i.e. points predicted per image.
        device: device every sub-module, including both DINO backbones, is
            created on.
        img_size: factor applied to the sigmoid output. Defaults to 224, the
            value previously hard-coded. NOTE(review): presumably the input
            image side length — confirm.
    """

    def __init__(self, boundary_num=80, device="cuda", img_size=224):
        super(DinoDETR, self).__init__()
        # Forward `device`: previously both loaders fell back to their "cuda"
        # default regardless of this argument.
        self.dino4 = get_dino4(device)
        self.raw_dino = get_raw_dino(device)
        # freeze raw_dino: it is only used as a fixed feature extractor.
        for param in self.raw_dino.parameters():
            param.requires_grad = False
        self.hidden_dim = 384  # channel dim of the dino16 features
        self.img_size = img_size
        self.layernorm = nn.LayerNorm(self.hidden_dim).to(device)
        self.pos_enc = PositionalEncoding(self.hidden_dim).to(device)
        self.boundary_num = boundary_num
        # Move the tensor first and wrap it second. `nn.Parameter(...).to(device)`
        # returns a plain non-leaf tensor whenever a copy is needed (e.g. to
        # CUDA), which left the queries unregistered: not trained, not in
        # state_dict(), and not moved by model.to().
        self.queries = nn.Parameter(
            torch.rand(boundary_num, self.hidden_dim).to(device)
        )
        self.transformer1 = self._make_transformer().to(device)
        self.transformer2 = self._make_transformer().to(device)
        self.xy_fc = nn.Linear(self.hidden_dim, 2).to(device)

    def _make_transformer(self) -> nn.Transformer:
        """Build the 1-head, 1-encoder/1-decoder-layer, batch-first transformer."""
        return nn.Transformer(
            d_model=self.hidden_dim,
            nhead=1,
            num_encoder_layers=1,
            num_decoder_layers=1,
            batch_first=True,
        )

    def forward(
        self,
        fir_img: torch.Tensor,
        fir_mask: torch.Tensor,
        pre_img: torch.Tensor,
        pre_mask: torch.Tensor,
        cur_img: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict point coordinates for ``cur_img``.

        Args:
            fir_img, fir_mask: first frame and its mask. They are concatenated
                along dim 1 before ``dino4`` (assumes (B, 3, H, W) and
                (B, 1, H, W) — TODO confirm).
            pre_img, pre_mask: previous frame and its mask, handled the same way.
            cur_img: current frame, fed to the frozen ``raw_dino``.

        Returns:
            Tensor of shape (B, boundary_num, 2): sigmoid outputs scaled by
            ``self.img_size``.
        """
        # 1. Memory tokens from the first and previous (image, mask) pairs.
        fir_tokens = get_img_tokens(self.dino4(torch.cat([fir_img, fir_mask], dim=1)))
        pre_tokens = get_img_tokens(self.dino4(torch.cat([pre_img, pre_mask], dim=1)))
        mem_tokens = torch.cat((fir_tokens, pre_tokens), dim=1)
        mem_tokens = self.pos_enc(self.layernorm(mem_tokens))

        # 2. Decode the learned queries (one copy per sample) against the memory.
        batch_size = mem_tokens.shape[0]
        queries = self.queries.unsqueeze(0).expand(batch_size, -1, -1)
        x = self.transformer1(mem_tokens, queries)

        # 3. Decode the stage-1 output against the current frame's tokens.
        cur_tokens = get_img_tokens(self.raw_dino(cur_img))
        cur_tokens = self.pos_enc(self.layernorm(cur_tokens))
        x = self.transformer2(cur_tokens, x)

        # 4. Two values per query in [0, 1], scaled to output units.
        return self.xy_fc(x).sigmoid() * self.img_size



In [97]:
# Build the full model. NOTE(review): the constructor is called with its default
# device="cuda"; `.to(device)` then moves the finished model.
model = DinoDETR().to(device)

Using cache found in /zhome/32/f/202284/.cache/torch/hub/mhamilton723_FeatUp_main
Using cache found in /zhome/32/f/202284/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /zhome/32/f/202284/.cache/torch/hub/mhamilton723_FeatUp_main
Using cache found in /zhome/32/f/202284/.cache/torch/hub/facebookresearch_dino_main


In [100]:
# Time a single forward pass. CUDA launches kernels asynchronously, so
# synchronize before reading each clock. The original timer also wrapped
# `print`, whose device-to-host copy and formatting were counted as model time.
if device.type == "cuda":
    torch.cuda.synchronize()
start_time = time.perf_counter()
xy_pred = model(fir_img, fir_mask, pre_img, pre_mask, cur_img)
if device.type == "cuda":
    torch.cuda.synchronize()
end_time = time.perf_counter()
print(xy_pred)
print(f"Time: {end_time - start_time}")

tensor([[[129.4344, 153.6473],
         [ 99.5774, 105.5077],
         [ 82.0723, 148.6579],
         [ 79.6153, 161.4762],
         [ 74.4268, 152.0893],
         [111.8938, 123.1397],
         [ 87.8885, 155.5473],
         [ 98.8817, 161.5239],
         [107.7761, 146.3063],
         [ 81.8827, 158.4506],
         [116.5304, 137.9882],
         [ 89.2141, 145.6071],
         [ 97.6094, 138.0553],
         [ 89.5063, 132.3686],
         [ 98.5212, 174.6442],
         [ 91.2148, 137.9995],
         [114.2734, 152.7850],
         [100.6595, 159.2484],
         [ 97.6542, 119.3128],
         [134.6010, 151.2711],
         [ 98.0765, 140.0412],
         [118.0413, 162.2771],
         [103.3818, 170.9048],
         [113.5168, 118.3595],
         [ 76.5599, 164.0641],
         [ 89.3307, 165.6019],
         [ 70.0166, 154.9545],
         [ 97.1229, 161.1718],
         [114.8269, 155.5305],
         [106.6668, 168.3149],
         [ 89.1700, 163.9519],
         [107.4631, 155.6739],
        