In [None]:
# default_exp models.transformer

# Transformer model
> Inspired by DETR: https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_demo.ipynb#scrollTo=h91rsIPl7tVl

In [None]:
#export
from fastai2.vision.all import *
from moving_mnist.models.conv_rnn import *

In [None]:
# Select the GPU used by this notebook. The second GPU (index 1) is preferred,
# but only when the machine actually has one: calling `set_device(1)` on a
# single-GPU machine raises, which would break "Restart & Run All" elsewhere.
if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)
    # Report the name of whichever device ended up being the current one
    print(torch.cuda.get_device_name())

GeForce RTX 2070 SUPER


## Encoder

In [None]:
#export
@delegates(create_cnn_model)
class Encoder(Module):
    "Feature extractor that keeps only the first child (`[0]`) of a `create_cnn_model` network"
    def __init__(self, arch=resnet34, n_in=3, weights_file=None, n_out=1, strict=False, pretrained=False, **kwargs):
        "Encoder based on resnet, returns the feature map"
        # Build the complete network first so that a checkpoint of the full
        # model can be loaded before everything but the body is discarded.
        full_net = create_cnn_model(arch, n_out=n_out, n_in=n_in, pretrained=pretrained, **kwargs)
        load_weights = weights_file is not None
        if load_weights:
            # NOTE(review): the message below assumes `load_model` returns the
            # missing keys -- confirm against the fastai2 version in use.
            load_res = load_model(weights_file, full_net, opt=None, strict=strict)
            print(f'Loading model from file {weights_file} \n>missing keys: {load_res}')
        # Only the first child of the created model is kept as the encoder
        self.body = full_net[0]

    def forward(self, x):
        "Run `x` through the body and return its output"
        feature_map = self.body(x)
        return feature_map

We can use any torchvision architecture as the backbone (resnet, vgg, inception, etc.).

In [None]:
# Default architecture (resnet34), created with `pretrained=True`
r34_encoder = Encoder(pretrained=True)

This model encodes an image into a feature map with 512 channels:

In [None]:
# Shape check on a random batch of 8 three-channel 128x128 inputs
r34_encoder(torch.rand(8, 3, 128, 128)).shape

torch.Size([8, 512, 4, 4])

We recover a Tensor that has `512` channels and a spatial size of `(4, 4)`.

## DETRdemo

In [None]:
class DETRdemo(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Any batch size is supported: the object queries are tiled across the batch.

    Args:
        num_classes: number of object classes (one extra "no object" logit is added).
        hidden_dim: transformer width; must be even because the positional
            encoding is the concatenation of two `hidden_dim // 2` halves.
        nheads: number of attention heads of the transformer.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
        debug: when True, `forward` prints the shapes of intermediate tensors.
    """
    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6, debug=False):
        super().__init__()
        self.debug = debug
        # create ResNet-50 backbone (the classification head is not used)
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        """
        Run the backbone, the transformer and both prediction heads.

        Returns a dict with 'pred_logits' of shape (batch, 100, num_classes + 1)
        and 'pred_boxes' of shape (batch, 100, 4), the boxes being squashed to
        (0, 1) by a sigmoid.
        """
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings: every (row, col) cell gets the
        # concatenation of its column embedding and its row embedding
        batch_size = h.shape[0]
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        if self.debug: print(f'pos: {pos.shape}')

        # (H*W, batch, hidden_dim); `pos` broadcasts over the batch dimension
        tf_input = pos + 0.1 * h.flatten(2).permute(2, 0, 1)
        if self.debug: print(f'tf_input: {tf_input.shape}')

        # nn.Transformer needs `src` and `tgt` to share the batch size, so the
        # learned queries are tiled once per sample instead of fixed at batch 1
        queries = self.query_pos.unsqueeze(1).repeat(1, batch_size, 1)

        # propagate through the transformer
        h = self.transformer(tf_input, queries).transpose(0, 1)
        if self.debug: print(f'tf_out: {h.shape}')

        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid()}

In [None]:
# 10 classes; debug=True makes `forward` print the intermediate tensor shapes
demo = DETRdemo(10, debug=True)

In [None]:
# Forward one random 3x128x128 input; the returned dict of tensors is displayed in full
demo(torch.rand(1,3,128,128))

pos: torch.Size([16, 1, 256])
tf_input: torch.Size([16, 1, 256])
tf_out: torch.Size([1, 100, 256])


{'pred_logits': tensor([[[ 0.3864, -0.2979,  1.3025,  ..., -0.5235, -0.2826,  0.1175],
          [-0.4414, -0.5061,  1.1318,  ..., -0.4527, -0.2506,  0.5684],
          [ 0.0783, -0.2800,  1.1213,  ..., -0.4288, -0.2279,  0.2643],
          ...,
          [-0.0739, -0.3486,  1.3377,  ..., -0.9759, -0.1284,  0.9343],
          [ 0.2017, -0.4327,  1.2409,  ..., -0.9430, -0.5388,  0.3167],
          [ 0.1249, -0.1855,  1.2717,  ..., -0.5588, -0.1075,  0.5636]]],
        grad_fn=<AddBackward0>),
 'pred_boxes': tensor([[[0.3038, 0.5805, 0.6628, 0.3024],
          [0.3248, 0.4435, 0.7345, 0.3337],
          [0.3113, 0.5762, 0.7161, 0.3430],
          [0.4082, 0.5398, 0.5523, 0.3664],
          [0.3668, 0.4667, 0.6369, 0.4547],
          [0.3351, 0.4824, 0.6206, 0.5372],
          [0.3091, 0.4184, 0.6198, 0.3836],
          [0.4208, 0.4537, 0.5953, 0.4685],
          [0.3357, 0.4424, 0.6408, 0.4808],
          [0.2996, 0.3632, 0.6173, 0.3693],
          [0.2764, 0.4800, 0.7029, 0.3758],
     

## Transformer Model

We will try an architecture with an Encoder/Decoder model provided by the Transformer, instead of the ConvGRU layer.

In [None]:
#export
class DETR(Module):
    """
    DETR-style sequence model: per-frame CNN encoder -> `nn.Transformer` -> per-frame upsampling decoder.

    Args:
        n_in: number of channels of each input frame.
        n_out: number of channels of each output frame.
        hidden_dim: transformer width; must be divisible by 4 because the
            positional encoding concatenates a `hidden_dim // 2` time embedding
            with two `hidden_dim // 4` spatial embeddings.
        nheads: number of attention heads of the transformer.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
        debug: when True, `forward` prints the shapes of intermediate tensors.
        num_queries: number of learned output queries; must be at least
            `T*H*W` (frames x feature-map height x width) of the inputs.
    """
    def __init__(self,  n_in=1, n_out=1, hidden_dim=256, nheads=8, num_encoder_layers=6, 
                 num_decoder_layers=6, debug=False, num_queries=100):
        self.debug = debug

        # the image encoder, wrapped in TimeDistributed to handle the frame dimension
        self.backbone = TimeDistributed(Encoder(n_in=n_in))

        # create conversion layer (512 = channels produced by the default Encoder)
        self.conv = TimeDistributed(nn.Conv2d(512, hidden_dim, 1))

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(num_queries, hidden_dim))

        # spatial and temporal positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 4))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 4))
        self.time_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

        # decoder: its first block must accept the transformer width, so it is
        # tied to `hidden_dim` rather than a literal 256
        self.decoder = TimeDistributed(nn.Sequential(
                      UpsampleBlock(hidden_dim, 128, residual=False),
                      UpsampleBlock(128, 128, residual=False),
                      UpsampleBlock(128, 64, residual=False),
                      UpsampleBlock(64, 32, residual=False),
                      UpsampleBlock(32, 16, residual=False),
                      nn.Conv2d(16, n_out, 1)
                    ))

    def forward(self, inputs):
        """
        Encode every frame, mix all (time, row, col) tokens with the transformer
        and decode one output frame per input frame.

        `inputs` is indexed as (batch, time, channels, height, width).
        Raises ValueError when the clip yields more tokens than there are queries.
        """
        # propagate inputs through ResNet up to avg-pool layer
        x = self.backbone(inputs)
        if self.debug: print(f'backbone: {x.shape}')

        # convert from the latent dim to 256 feature planes for the transformer
        h = self.conv(x)
        if self.debug: print(f'h: {h.shape}')

        B, T = h.shape[:2]
        H, W = h.shape[-2:]
        n_tokens = T * H * W
        n_queries = self.query_pos.shape[0]
        if n_tokens > n_queries:
            raise ValueError(
                f'Need one query per token: T*H*W={n_tokens} exceeds num_queries={n_queries}')

        # construct positional encodings: each (t, row, col) token gets
        # [time | column | row] embeddings concatenated on the feature axis
        pos = torch.cat([
            self.time_embed[:T].view(T,1,1,-1).repeat(1, H, W, 1),
            self.col_embed[:W].view(1,1,W,-1).repeat(T, H, 1, 1),
            self.row_embed[:H].view(1,H,1,-1).repeat(T, 1, W, 1),
        ], dim=-1).flatten(0, 2).unsqueeze(1)
        if self.debug: print(f'pos: {pos.shape}')

        # tokens ordered as (t, row, col): (T*H*W, B, hidden_dim); `pos` broadcasts over B
        tf_input = pos + 0.1 * h.permute(0,2,1,3,4).flatten(2).permute(2,0,1)
        if self.debug: print(f'tf_input: {tf_input.shape}')

        # nn.Transformer needs `src` and `tgt` to share the batch size
        queries = self.query_pos.unsqueeze(1).repeat(1, B, 1)

        # propagate through the transformer
        h = self.transformer(tf_input, queries).transpose(0, 1)
        if self.debug: print(f'tf_out: {h.shape}')

        # Keep one query output per input token and rebuild the (T, H, W) grid.
        # The feature axis is last in the transformer output, so it has to be
        # moved in front of the spatial axes explicitly; a plain
        # `.view(B, T, -1, H, W)` would interleave channels and positions.
        feats = h[:, :n_tokens].reshape(B, T, H, W, -1).permute(0, 1, 4, 2, 3).contiguous()
        return self.decoder(feats)

In [None]:
# debug=True makes `forward` print the intermediate tensor shapes
detr = DETR(debug=True)

In [None]:
# h = torch.rand(1, 5, 256, 4, 4)

# h.permute(0,2,1,3,4).flatten(2).permute(2,0,1).shape

# H, W = h.shape[-2:]
# T = h.shape[1]

# h.flatten(3).shape

# T

# detr.time_embed[:T].view(T,1,1,-1).repeat(1, H, W, 1).shape

# detr.col_embed[:W].view(1,1,W,-1).repeat(T, H, 1, 1).shape

# detr.row_embed[:H].view(1,H,1,-1).repeat(T, 1, W, 1).shape

In [None]:
# Forward a random input of shape (1, 5, 1, 128, 128) and display only the output shape
detr(torch.rand(1,5,1,128,128)).shape

backbone: torch.Size([1, 5, 512, 4, 4])
h: torch.Size([1, 5, 256, 4, 4])
pos: torch.Size([80, 1, 256])
tf_input: torch.Size([80, 1, 256])
tf_out: torch.Size([1, 100, 256])


torch.Size([1, 5, 1, 128, 128])

# Export -

In [None]:
# hide
from nbdev.export import *
# Run the nbdev notebook-to-script conversion (one "Converted ..." line is printed per notebook)
notebook2script()

Converted 00_data.ipynb.
Converted 01_models.conv_rnn.ipynb.
Converted 02_models.transformer.ipynb.
Converted index.ipynb.
