# Packages

In [1]:
import torch.nn as nn
from typing import Optional, List
from torch import Tensor
import torch
import math
import torchvision
from torchvision.models._utils import IntermediateLayerGetter
import copy
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


# Backbone

In [2]:
class NestedTensor(object):
    """Bundle a batched tensor together with an optional mask.

    Both values are stored exactly as given; ``mask`` may be ``None``.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors, self.mask = tensors, mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with the tensor (and the mask, if any) moved to ``device``."""
        moved_tensors = self.tensors.to(device)
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return (self.tensors, self.mask)

    def __repr__(self):
        # Only the tensor is shown; the mask is not part of the repr.
        return str(self.tensors)

In [3]:
class PositionEmbeddingSine(nn.Module):
    """
    Sine/cosine position embedding for 2-D inputs, in the spirit of the
    "Attention is all you need" encoding but applied along both image axes.

    Parameters:
        num_pos_feats: number of channels produced per axis (the output has twice as many).
        temperature: base of the geometric frequency progression.
        normalize: when True, running coordinates are rescaled to ``[0, scale]``.
        scale: upper bound used by the normalisation; defaults to ``2 * pi``.

    Raises:
        ValueError: if ``scale`` is given while ``normalize`` is False.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor_list: NestedTensor):
        """Compute the embedding for ``tensor_list``.

        ``tensor_list.mask`` must be a 3-D boolean tensor; it is inverted and
        cumulatively summed along axes 1 and 2 to obtain per-position coordinates.
        The result is permuted so that the embedding channels sit on axis 1.
        """
        features, padding = tensor_list.tensors, tensor_list.mask
        assert padding is not None
        valid = ~padding
        # Running counts of unmasked positions act as (row, column) coordinates.
        row_coord = valid.cumsum(1, dtype=torch.float32)
        col_coord = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            row_coord = row_coord / (row_coord[:, -1:, :] + eps) * self.scale
            col_coord = col_coord / (col_coord[:, :, -1:] + eps) * self.scale

        freq = torch.arange(self.num_pos_feats, dtype=torch.float32, device=features.device)
        freq = self.temperature ** (2 * (freq // 2) / self.num_pos_feats)

        def _sin_cos(coord):
            # Even channels carry the sine, odd channels the cosine, interleaved.
            angles = coord[:, :, :, None] / freq
            paired = torch.stack((angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4)
            return paired.flatten(3)

        embed_rows = _sin_cos(row_coord)
        embed_cols = _sin_cos(col_coord)
        return torch.cat((embed_rows, embed_cols), dim=3).permute(0, 3, 1, 2)



In [4]:
def build_position_encoding(args):
    """Create a normalised ``PositionEmbeddingSine`` with ``args["hidden_dim"] // 2`` features per axis."""
    feats_per_axis = args["hidden_dim"] // 2
    return PositionEmbeddingSine(feats_per_axis, normalize=True)

In [5]:
class BackboneBase(nn.Module):
    """Wrap a backbone network, freeze part of it and expose selected layer outputs.

    Parameters:
        backbone: the network whose ``layer1``..``layer4`` outputs are tapped.
        train_backbone: when falsy, every backbone parameter is frozen; otherwise only
            parameters whose name contains ``layer2``, ``layer3`` or ``layer4`` stay trainable.
        num_channels: stored on the instance as ``num_channels``.
        return_interm_layers: when truthy, outputs of all four layers are returned,
            otherwise only ``layer4``.
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        trainable_stages = ('layer2', 'layer3', 'layer4')
        for name, parameter in backbone.named_parameters():
            in_trainable_stage = any(stage in name for stage in trainable_stages)
            if not (train_backbone and in_trainable_stage):
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {f"layer{idx + 1}": str(idx) for idx in range(4)}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and pair every returned feature map with a resized mask.

        Returns a dict mapping the output name to a ``NestedTensor``.
        """
        feature_maps = self.body(tensor_list.tensors)
        out = {}
        for name, feat in feature_maps.items():
            pad = tensor_list.mask
            assert pad is not None
            # Resize the mask to the spatial size of this feature map.
            resized = F.interpolate(pad[None].float(), size=feat.shape[-2:])
            out[name] = NestedTensor(feat, resized.to(torch.bool)[0])
        return out

In [6]:
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm.

    ``name`` selects the constructor from ``torchvision.models``; it is created with
    pretrained weights, ``FrozenBatchNorm2d`` as the norm layer and optional dilation
    in the last stage.
    """
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        make_resnet = getattr(torchvision.models, name)
        backbone = make_resnet(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=True, norm_layer=FrozenBatchNorm2d)
        # resnet18/resnet34 are registered with 512 output channels, everything else with 2048.
        if name in ('resnet18', 'resnet34'):
            num_channels = 512
        else:
            num_channels = 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

In [7]:
class Joiner(nn.Sequential):
    """Chain a backbone (index 0) with a position-embedding module (index 1)."""

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        """Return ``(features, position_encodings)`` as two parallel lists.

        Every position encoding is cast to the dtype of its feature tensor.
        """
        backbone, position_embedding = self[0], self[1]
        feature_maps = backbone(tensor_list)
        out = list(feature_maps.values())
        pos = [position_embedding(feat).to(feat.tensors.dtype) for feat in out]
        return out, pos

In [8]:
# Configuration dict read by the builder functions in this notebook.
# NOTE(review): 'masks' and 'dilation' hold the string 'store_true', which looks like an
# argparse action name rather than a boolean. Any non-empty string is truthy, so both
# options are effectively switched on where they are used as flags — confirm this is intended.
args = {
    "hidden_dim": 256,
    "lr_backbone": 1e-5,
    'masks': 'store_true',
    'backbone': 'resnet50',
    'dilation': 'store_true'
}

In [9]:
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d whose statistics and affine parameters are constant buffers.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered as
    buffers (not parameters), and a fixed eps is added to the variance before the
    reciprocal square root.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        initial_buffers = (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        )
        for buffer_name, make in initial_buffers:
            self.register_buffer(buffer_name, make(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # This module has no ``num_batches_tracked`` buffer, so drop that entry
        # before delegating to the default loader.
        state_dict.pop(prefix + 'num_batches_tracked', None)

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Apply ``x * scale + bias`` with per-channel constants broadcast over axis 1."""
        def per_channel(buf):
            # Reshape up front so the arithmetic below is purely element-wise.
            return buf.reshape(1, -1, 1, 1)

        eps = 1e-5
        scale = per_channel(self.weight) * (per_channel(self.running_var) + eps).rsqrt()
        bias = per_channel(self.bias) - per_channel(self.running_mean) * scale
        return x * scale + bias

In [10]:
# Build the backbone + position-encoding stack step by step
# (the same steps that build_backbone wraps into a function).
position_embedding = build_position_encoding(args)
train_backbone = args['lr_backbone'] > 0
# Read the flag from the config: the former literal list ['masks'] never consulted
# args and was only truthy by accident.
return_interm_layers = args['masks']
backbone = Backbone(args['backbone'], train_backbone, return_interm_layers, args['dilation'])
model = Joiner(backbone, position_embedding)
# Expose the backbone's channel count on the joined model for downstream consumers.
model.num_channels = backbone.num_channels

In [11]:
def build_backbone(args):
    """Assemble the backbone and its position encoding into one ``Joiner``.

    Reads ``hidden_dim``, ``lr_backbone``, ``masks``, ``backbone`` and ``dilation``
    from ``args``. The returned model carries the backbone's ``num_channels``.
    """
    pos_enc = build_position_encoding(args)
    # The backbone is trained only when a positive backbone learning rate is configured.
    should_train = args['lr_backbone'] > 0
    want_interm_layers = args['masks']
    resnet = Backbone(args['backbone'], should_train, want_interm_layers, args['dilation'])
    joined = Joiner(resnet, pos_enc)
    joined.num_channels = resnet.num_channels
    return joined

In [12]:
# Rebind `backbone` to the joined backbone + position-encoding model returned by build_backbone.
backbone = build_backbone(args)

# Transformer

In [13]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a two-layer feed-forward block.

    ``normalize_before`` chooses between the pre-norm (``forward_pre``) and
    post-norm (``forward_post``) arrangement of the LayerNorms.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward block: d_model -> dim_feedforward -> d_model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add ``pos`` to ``tensor`` unless ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        """Post-norm variant: each LayerNorm is applied after its residual sum."""
        # Query and key are the token features plus the position encoding;
        # the value is the plain token features, without the position term.
        attn_input = self.with_pos_embed(src, pos)
        attended = self.self_attn(attn_input, attn_input, value=src, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]
        src = self.norm1(src + self.dropout1(attended))
        hidden = self.activation(self.linear1(src))
        ff_out = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout2(ff_out))

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        """Pre-norm variant: each LayerNorm is applied before its sub-block."""
        normed = self.norm1(src)
        attn_input = self.with_pos_embed(normed, pos)
        attended = self.self_attn(attn_input, attn_input, value=normed, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(attended)
        normed = self.norm2(src)
        hidden = self.activation(self.linear1(normed))
        ff_out = self.linear2(self.dropout(hidden))
        return src + self.dropout2(ff_out)

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm implementation."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(src, src_mask, src_key_padding_mask, pos)

In [14]:
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, attention over ``memory``, then feed-forward.

    ``normalize_before`` chooses between the pre-norm (``forward_pre``) and
    post-norm (``forward_post``) arrangement of the LayerNorms.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward block: d_model -> dim_feedforward -> d_model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add ``pos`` to ``tensor`` unless ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: each LayerNorm is applied after its residual sum."""
        # Self-attention over the target sequence.
        self_in = self.with_pos_embed(tgt, query_pos)
        self_out = self.self_attn(self_in, self_in, value=tgt, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)[0]
        tgt = self.norm1(tgt + self.dropout1(self_out))
        # Attention from the target (query) to the memory (key/value).
        cross_out = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                        key=self.with_pos_embed(memory, pos),
                                        value=memory, attn_mask=memory_mask,
                                        key_padding_mask=memory_key_padding_mask)[0]
        tgt = self.norm2(tgt + self.dropout2(cross_out))
        # Feed-forward block.
        hidden = self.activation(self.linear1(tgt))
        ff_out = self.linear2(self.dropout(hidden))
        return self.norm3(tgt + self.dropout3(ff_out))

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: each LayerNorm is applied before its sub-block."""
        normed = self.norm1(tgt)
        self_in = self.with_pos_embed(normed, query_pos)
        self_out = self.self_attn(self_in, self_in, value=normed, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(self_out)
        normed = self.norm2(tgt)
        cross_out = self.multihead_attn(query=self.with_pos_embed(normed, query_pos),
                                        key=self.with_pos_embed(memory, pos),
                                        value=memory, attn_mask=memory_mask,
                                        key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(cross_out)
        normed = self.norm3(tgt)
        hidden = self.activation(self.linear1(normed))
        ff_out = self.linear2(self.dropout(hidden))
        return tgt + self.dropout3(ff_out)

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm implementation."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(tgt, memory, tgt_mask, memory_mask,
                   tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)

In [15]:
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

In [16]:
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep copies of ``encoder_layer`` with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,  # attention mask; callers in this notebook leave it unset
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Pass ``src`` through every layer in order, then through ``norm`` if one was given.

        When called from Transformer_vis (author's notes, translated):
          - src is the last backbone feature map projected to hidden_dim, shape (h*w, b, hidden_dim)
          - pos is the position encoding of that feature map, shape (h*w, b, c)
          - src_key_padding_mask is the mask of that feature map, shape (b, h*w)
        """
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask=mask,
                                   src_key_padding_mask=src_key_padding_mask, pos=pos)

        return hidden if self.norm is None else self.norm(hidden)

In [17]:
class Transformer_vis(nn.Module):
    """Encoder-only transformer applied to a 2-D feature map.

    The 4-D input is flattened into a sequence, encoded, and reshaped back to the
    original ``(bs, c, h, w)`` layout.
    """

    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6,dim_feedforward=2048,
                    dropout=0.1, activation="relu", normalize_before=False):
        super().__init__()

        layer_template = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                 dropout, activation, normalize_before)
        # A closing LayerNorm is only added for the pre-norm configuration.
        final_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(layer_template, num_encoder_layers, final_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter that has more than one dimension."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def forward(self, src, mask, pos_embed):
        """Encode ``src`` and return a tensor with the same 4-D shape as ``src``."""
        bs, c, h, w = src.shape
        # NxCxHxW -> HWxNxC for the features and the position encoding; the mask becomes NxHW.
        seq = src.flatten(2).permute(2, 0, 1)
        seq_pos = pos_embed.flatten(2).permute(2, 0, 1)
        flat_mask = mask.flatten(1)

        memory = self.encoder(seq, src_key_padding_mask=flat_mask, pos=seq_pos)

        # HWxNxC -> NxCxHW -> NxCxHxW
        return memory.permute(1, 2, 0).view(bs, c, h, w)

In [18]:
# Configuration dict, redefined here with the transformer-related keys added.
# NOTE(review): 'masks', 'dilation' and 'pre_norm' hold the string 'store_true' (looks like
# an argparse action name). The string is truthy, so e.g. normalize_before=args['pre_norm']
# selects the pre-norm path — confirm booleans were not intended.
args = {
    "hidden_dim": 256,
    "lr_backbone": 1e-5,
    'masks': 'store_true',
    'backbone': 'resnet50',
    'dilation': 'store_true',
    'dropout': 0.1,
    'nheads': 8,
    'dim_feedforward': 2048,
    'enc_layers': 6,
    'pre_norm': 'store_true',
    'device': "cuda"




}

In [19]:
def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

In [20]:
def build_vis_transformer(args):
    """Create a ``Transformer_vis`` configured from ``args``.

    Reads ``hidden_dim``, ``dropout``, ``nheads``, ``dim_feedforward``,
    ``enc_layers`` and ``pre_norm``.
    """
    settings = dict(
        d_model=args['hidden_dim'],
        dropout=args['dropout'],
        nhead=args['nheads'],
        dim_feedforward=args['dim_feedforward'],
        num_encoder_layers=args['enc_layers'],
        normalize_before=args['pre_norm'],
    )
    return Transformer_vis(**settings)

In [21]:
# Encoder-only visual transformer built from args (`visual_model` is rebound further down by the DETR cell).
visual_model = build_vis_transformer(args)

# DETR

In [24]:
# Stand-alone Transformer_vis instance built from the current args.
transformer = build_vis_transformer(args)

In [25]:
class DETR(nn.Module):
    """Visual branch: backbone -> 1x1 projection -> transformer over the last feature level."""
    def __init__(self, backbone, transformer):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone; must expose ``num_channels``.
            transformer: torch module of the transformer; must expose ``d_model``.
        """
        super().__init__()
        self.transformer = transformer
        # A 1x1 convolution maps the backbone channels to the transformer width.
        self.input_proj = nn.Conv2d(backbone.num_channels, transformer.d_model, kernel_size=1)
        self.backbone = backbone

    def forward(self, img, mask):
        """ Wrap ``img`` and ``mask`` in a NestedTensor and encode the last feature level.
               - img: batched images, of shape [batch_size x 3 x H x W]
               - mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
            Returns whatever ``self.transformer`` returns for the projected features.
        """
        # The backbone yields parallel lists of feature levels and position encodings.
        feature_levels, pos_levels = self.backbone(NestedTensor(img, mask))

        last_features, last_mask = feature_levels[-1].decompose()
        assert last_mask is not None
        projected = self.input_proj(last_features)
        return self.transformer(projected, last_mask, pos_levels[-1])

In [26]:

def build_detr(args):
    """Build the visual model: a ResNet backbone joined with an encoder-only transformer.

    Parameters:
        args: configuration dict read by ``build_backbone`` and ``build_vis_transformer``.

    Returns:
        A ``DETR`` module. It is not moved to any device here; the caller does that.
    """
    # The former `device = torch.device(args['device'])` local was never used and is dropped.
    backbone = build_backbone(args)  # ResNet backbone + position encoding
    transformer = build_vis_transformer(args)

    return DETR(
        backbone,
        transformer,
    )

In [30]:
# Build the full visual model (backbone + transformer) and move it to the GPU; rebinds `visual_model`.
visual_model = build_detr(args).cuda()

# BERT

In [35]:
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertModel

In [38]:
# Configuration dict, redefined with 'dec_layers' and 'bert_model' added.
# NOTE(review): 'masks', 'dilation' and 'pre_norm' still hold the truthy string 'store_true'
# rather than booleans — confirm this is intended.
args = {
    "hidden_dim": 256,
    "lr_backbone": 1e-5,
    'masks': 'store_true',
    'backbone': 'resnet50',
    'dilation': 'store_true',
    'dropout': 0.1,
    'nheads': 8,
    'dim_feedforward': 2048,
    'enc_layers': 6,
    'dec_layers': 6,
    'pre_norm': 'store_true',
    'device': "cuda",
    'bert_model': 'bert-base-uncased'

}

In [64]:
from transformers import BertTokenizer, BertModel
# Load the BERT encoder and its tokenizer from the same configured checkpoint name,
# so the two cannot drift apart when args['bert_model'] is changed.
textmodel = BertModel.from_pretrained(args['bert_model']).cuda()
tokenizer = BertTokenizer.from_pretrained(args['bert_model'])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Visual-linguistic Fusion model

In [71]:
class Transformer(nn.Module):
    """Encoder-only transformer over an already flattened token sequence.

    ``num_decoder_layers`` and ``return_intermediate_dec`` are accepted but unused:
    the decoder is intentionally disabled and only the encoder is built.
    """

    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        layer_template = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                 dropout, activation, normalize_before)
        # A closing LayerNorm is only added for the pre-norm configuration.
        final_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(layer_template, num_encoder_layers, final_norm)

        # Decoder construction (TransformerDecoderLayer / TransformerDecoder) is disabled here.

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter that has more than one dimension."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def forward(self, src, mask, pos_embed):
        """Encode a token sequence.

        ``src`` is permuted with (2, 0, 1) and ``pos_embed`` with (1, 0, 2) before
        being handed to the encoder; ``mask`` is passed as the key-padding mask.
        Returns the encoder output without permuting it back.
        """
        # The unpacking also enforces that src is 3-D (NxCxW).
        _batch, _channels, _length = src.shape
        tokens = src.permute(2, 0, 1)            # NxCxW -> WxNxC
        token_pos = pos_embed.permute(1, 0, 2)   # swap the first two axes
        # No decoder pass: the encoder memory is the result.
        return self.encoder(tokens, src_key_padding_mask=mask, pos=token_pos)

In [72]:
class VLFusion(nn.Module):
    """Fuse visual and language features with a transformer and return the extra token's output."""

    def __init__(self, transformer, pos):
        """ Initializes the model.
        Parameters:
            transformer: torch module called as ``transformer(x, mask, pos)``; must expose
                ``d_model``. The projections below output 256 channels, so ``d_model``
                has to be 256 for the concatenation in ``forward`` to work.
            pos: module called on the fused ``[bs, 256, N]`` tensor to get its position encoding.
        """
        super().__init__()
        self.transformer = transformer
        self.pos = pos
        hidden_dim = transformer.d_model
        # Single learnable token appended after the visual and language tokens.
        self.pr = nn.Embedding(1, hidden_dim)

        # Per-token projections of the visual (256-d) and language (768-d) features.
        self.v_proj = torch.nn.Sequential(
          nn.Linear(256, 256),
          nn.ReLU(),)
        self.l_proj = torch.nn.Sequential(
          nn.Linear(768, 256),
          nn.ReLU(),)

    def forward(self, fv, fl):
        """ Fuse one batch of visual and language features.
               - fv: visual feature map of shape [bs x 256 x H x W]
               - fl: language features of shape [bs x L x 768]
            Returns the transformer output at the last (appended) token, i.e. ``out[-1]``.
        """
        bs, c, _h, _w = fv.shape  # unpacking enforces a 4-D visual input
        _, _, _ = fl.shape        # unpacking enforces a 3-D language input

        pv = self.v_proj(fv.flatten(2).permute(0, 2, 1))  # [bs, H*W, 256]
        pl = self.l_proj(fl)                              # [bs, L, 256]
        pv = pv.permute(0, 2, 1)  # [bs, 256, H*W]
        pl = pl.permute(0, 2, 1)  # [bs, 256, L]

        pr = self.pr.weight                  # [1, hidden_dim]
        pr = pr.expand(bs, -1).unsqueeze(2)  # [bs, hidden_dim, 1]

        # Token order along the last axis: visual tokens, language tokens, then the extra token.
        x0 = torch.cat((pv, pl), dim=2)
        x0 = torch.cat((x0, pr), dim=2)  # [bs, 256, H*W + L + 1]

        # NOTE(review): self.pos receives the raw fused tensor. PositionEmbeddingSine in this
        # notebook reads `.tensors`/`.mask` from its argument, so it does not fit this call —
        # confirm which position module is intended (presumably one returning [bs, N, 256]).
        pos = self.pos(x0).to(x0.dtype)
        # All-False key-padding mask, created on the same device as the inputs
        # instead of a hard-coded .cuda() so the module also runs on CPU.
        mask = torch.zeros((bs, x0.shape[2]), dtype=torch.bool, device=x0.device)

        out = self.transformer(x0, mask, pos)

        return out[-1]

In [76]:
def build_transformer(args):
    """Create the fusion ``Transformer`` configured from ``args``.

    Reads ``hidden_dim``, ``dropout``, ``nheads``, ``dim_feedforward``,
    ``enc_layers``, ``dec_layers`` and ``pre_norm``.
    """
    settings = dict(
        d_model=args['hidden_dim'],
        dropout=args['dropout'],
        nhead=args['nheads'],
        dim_feedforward=args['dim_feedforward'],
        num_encoder_layers=args['enc_layers'],
        num_decoder_layers=args['dec_layers'],
        normalize_before=args['pre_norm'],
        # TODO: return_intermediate_dec
        return_intermediate_dec=True,
    )
    return Transformer(**settings)

In [83]:
def build_VLFusion(args):
    """Build the visual-linguistic fusion model.

    Parameters:
        args: configuration dict read by ``build_transformer`` and ``build_position_encoding``.

    Returns:
        A ``VLFusion`` module. It is not moved to any device here; the caller does that.
    """
    # The former `device = torch.device(args['device'])` local was never used and is dropped.
    transformer = build_transformer(args)

    pos = build_position_encoding(args)

    return VLFusion(
        transformer,
        pos,
    )

In [84]:
# Configuration dict used to build the fusion model (this redefinition has no 'bert_model' key).
# NOTE(review): 'masks', 'dilation' and 'pre_norm' hold the truthy string 'store_true'
# rather than booleans — confirm this is intended.
args = {
    "hidden_dim": 256,
    "lr_backbone": 1e-5,
    'masks': 'store_true',
    'backbone': 'resnet50',
    'dilation': 'store_true',
    'dropout': 0.1,
    'nheads': 8,
    'dim_feedforward': 2048,
    'enc_layers': 6,
    'dec_layers': 6,
    'pre_norm': 'store_true',
    'device': "cuda"
}

In [100]:
# Build the fusion model and move it to the GPU (rebinds `model`, previously the Joiner backbone).
model = build_VLFusion(args).to("cuda")

In [103]:
# Dummy inputs for a smoke test of the fusion model.
# NOTE(review): `img` has 3 channels, but VLFusion.v_proj is nn.Linear(256, 256) applied over
# the channel axis, so the call in the next cell fails with the recorded RuntimeError.
# VLFusion appears to expect a 256-channel feature map (e.g. the visual model's output)
# rather than a raw image — confirm.
img = torch.ones((8,3,216,216)).to('cuda')
lang = torch.ones((8,216, 768)).to('cuda')

In [104]:
# Forward pass through the fusion model; with the inputs above this raises the RuntimeError shown below.
out = model(img, lang)

RuntimeError: mat1 dim 1 must match mat2 dim 0