In [None]:
# default_exp models.transformer

# Transformer model
> Inspired by DETR: https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_demo.ipynb#scrollTo=h91rsIPl7tVl

In [None]:
#export
from fastai2.vision.all import *
from moving_mnist.models.conv_rnn import *

In [None]:
if torch.cuda.is_available():
    # Prefer GPU 1 (the device originally selected here), but fall back to
    # GPU 0 when only one device is visible: index 1 would then be an
    # invalid device ordinal and the cell would fail on a fresh machine.
    gpu_idx = 1 if torch.cuda.device_count() > 1 else 0
    torch.cuda.set_device(gpu_idx)
    print(torch.cuda.get_device_name())

GeForce RTX 2070 SUPER


## Encoder

In [None]:
#export
@delegates(create_cnn_model)
class Encoder(Module):
    "Image encoder that keeps only element 0 (the body) of a `create_cnn_model` network and returns its feature map"
    def __init__(self, arch=resnet34, n_in=3, weights_file=None, n_out=1, strict=False, pretrained=False, **kwargs):
        "Build the full model for `arch`, optionally load `weights_file` into it, then retain `model[0]` as `self.body`"
        full_model = create_cnn_model(arch, n_in=n_in, n_out=n_out, pretrained=pretrained, **kwargs)
        if weights_file is not None:
            # Weights are loaded into the *full* model, before only element 0 is kept.
            # NOTE(review): what `load_model` returns is not visible from here; the
            # message below labels it "missing keys" -- confirm against the fastai version in use.
            load_res = load_model(weights_file, full_model, opt=None, strict=strict)
            print(f'Loading model from file {weights_file} \n>missing keys: {load_res}')
        self.body = full_model[0]

    def forward(self, x):
        "Run `x` through the retained body and return the result"
        features = self.body(x)
        return features

We can use any torchvision architecture (resnet, vgg, inception, etc.).

In [None]:
r34_encoder = Encoder(pretrained=True)

This model encodes an image into a feature map with 512 channels:

In [None]:
r34_encoder(torch.rand(8, 3, 128, 128)).shape

torch.Size([8, 512, 4, 4])

We recover a tensor with `512` channels and a spatial size of `(4,4)`.

## DETRdemo

In [None]:
class DETRdemo(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The upstream demo reports ~40 AP on COCO val5k and ~28 FPS on Tesla V100.

    Any batch size is accepted: the learned object queries are tiled along
    the batch axis so that the transformer's source and target batches match.

    Args:
        num_classes: number of object classes; one extra logit is added.
        hidden_dim: transformer model width (must be even, it is split in
            two halves for the column/row positional embeddings).
        nheads: number of attention heads.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
        debug: when True, print intermediate tensor shapes in `forward`.
    """
    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6, debug=False):
        super().__init__()
        self.debug = debug
        # create ResNet-50 backbone; its classification layer is not used
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer (2048 backbone channels -> hidden_dim)
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings (at most 50 rows / 50 columns)
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        """
        Return a dict with `pred_logits` of shape (batch, 100, num_classes + 1)
        and `pred_boxes` of shape (batch, 100, 4), the latter squashed to
        [0, 1] by a sigmoid.
        """
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to hidden_dim feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings: (H*W, 1, hidden_dim), broadcast over the batch
        bs = h.shape[0]
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        if self.debug: print(f'pos: {pos.shape}')

        # sequence-first layout expected by nn.Transformer: (H*W, bs, hidden_dim)
        tf_input = pos + 0.1 * h.flatten(2).permute(2, 0, 1)
        if self.debug: print(f'tf_input: {tf_input.shape}')

        # nn.Transformer requires source and target to share the batch size,
        # so the (100, hidden_dim) queries are tiled to (100, bs, hidden_dim)
        queries = self.query_pos.unsqueeze(1).repeat(1, bs, 1)

        # propagate through the transformer, then go back to batch-first
        h = self.transformer(tf_input, queries).transpose(0, 1)
        if self.debug: print(f'tf_out: {h.shape}')

        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid()}

In [None]:
demo = DETRdemo(10, debug=True)

In [None]:
demo(torch.rand(1,3,128,128));

pos: torch.Size([16, 1, 256])
tf_input: torch.Size([16, 1, 256])
tf_out: torch.Size([1, 100, 256])


## Transformer Model

We will try an architecture with an Encoder/Decoder model provided by the Transformer, instead of the ConvGRU layer.

In [None]:
#export
class DETR(Module):
    """
    DETR-style sequence model: per-frame CNN encoder -> `nn.Transformer` -> per-frame upsampling decoder.

    Args:
        arch: backbone architecture passed to `Encoder`.
        n: number of transformer input positions; it must equal T*H*W of the
           encoded sequence (80 = 5 frames * 4 * 4 in the demo cell of this notebook).
        n_in: number of input image channels.
        n_out: number of output image channels.
        hidden_dim: transformer width and number of channels fed to the decoder.
        nheads, num_encoder_layers, num_decoder_layers: forwarded to `nn.Transformer`.
        debug: when True, print intermediate tensor shapes in `forward`.
        n_queries: number of learned object queries (previously hard-coded to 100).
    """
    def __init__(self,  arch=resnet34, n=80, n_in=1, n_out=1, hidden_dim=256, nheads=4, num_encoder_layers=4, 
                 num_decoder_layers=4, debug=False, n_queries=100):
        self.debug = debug

        # the image encoder, applied to every frame of the sequence
        self.backbone = TimeDistributed(Encoder(arch, n_in=n_in, pretrained=True))

        # conversion layer; 512 is the channel count the default resnet34 body produces
        self.conv = TimeDistributed(nn.Conv2d(512, hidden_dim, 1))

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(n_queries, hidden_dim))

        # learned positional encoding, one vector per (time, row, col) position
        # note that in baseline DETR we use sine positional encodings;
        # a factorised (time, row, column) embedding is a possible alternative
        self.pos = nn.Parameter(torch.rand(n, hidden_dim))

        # decoder: five upsampling blocks followed by a 3x3 conv to n_out channels
        self.decoder = TimeDistributed(nn.Sequential(
                           UpsampleBlock(hidden_dim, 128, residual=False),
                           UpsampleBlock(128, 128, residual=False),
                           UpsampleBlock(128, 64, residual=False),
                           UpsampleBlock(64, 32, residual=False),
                           UpsampleBlock(32, 16, residual=False),
                           nn.Conv2d(16, n_out, 3,1,1))
                                      )
        # maps the n_queries transformer outputs back onto the n input positions
        self.lin = nn.Linear(n_queries, n)

    def forward(self, inputs):
        "Encode each frame of `inputs`, mix all positions with the transformer and decode back to frames"
        # propagate inputs through ResNet up to avg-pool layer
        x = self.backbone(inputs)
        if self.debug: print(f'backbone: {x.shape}')

        # convert from the latent dim to hidden_dim feature planes for the transformer
        h = self.conv(x)
        if self.debug: print(f'h: {h.shape}')
        bs, T, C, H, W = h.shape

        # positional encodings, broadcast over the batch axis
        pos = self.pos.unsqueeze(1)
        if self.debug: print(f'pos: {pos.shape}')

        # (bs, T, C, H, W) -> (bs, C, T*H*W) -> (T*H*W, bs, C)
        tf_input = pos + 0.1 * h.permute(0,2,1,3,4).flatten(2).permute(2,0,1)
        if self.debug: print(f'tf_input: {tf_input.shape}')

        # nn.Transformer needs matching source/target batch sizes, so the
        # queries are tiled to (n_queries, bs, C) instead of a fixed batch of 1
        queries = self.query_pos.unsqueeze(1).repeat(1, bs, 1)
        h = self.transformer(tf_input, queries).permute(2,1,0)   # (C, bs, n_queries)
        if self.debug: print(f'tf_out: {h.shape}')
        h = self.lin(h)                                          # (C, bs, T*H*W)
        if self.debug: print(f'lin: {h.shape}')

        # Undo the flattening done for `tf_input`: the last axis is (T, H, W) and
        # channels are the leading axis, so unflatten first and then move the axes.
        # A plain view to (bs, T, C, H, W) would mix channels with time steps and
        # would only be shape-valid for a batch of one.
        h = h.reshape(C, bs, T, H, W).permute(1, 2, 0, 3, 4).contiguous()
        if self.debug: print(f'before dec: {h.shape}')
        return self.decoder(h)

In [None]:
#export
def detr_split(model, stacked=False):
    "Split `model` into two parameter groups: [backbone, everything else]; with `stacked`, look inside `model.module`"
    # When stacked, the DETR instance lives under `.module` of the wrapper
    net = model.module if stacked else model
    backbone_group = params(net.backbone)
    rest_group = (params(net.conv) + params(net.transformer)
                  + [net.query_pos] + [net.pos]
                  + params(net.decoder) + params(net.lin))
    return [backbone_group, rest_group]

In [None]:
detr = DETR(debug=True)

In [None]:
split=detr_split(detr)

In [None]:
detr(torch.rand(1,5,1,128,128)).shape

backbone: torch.Size([1, 5, 512, 4, 4])
h: torch.Size([1, 5, 256, 4, 4])
pos: torch.Size([80, 1, 256])
tf_input: torch.Size([80, 1, 256])
tf_out: torch.Size([256, 1, 100])
lin: torch.Size([256, 1, 80])
before dec: torch.Size([1, 5, 256, 4, 4])


torch.Size([1, 5, 1, 128, 128])

## Integration

In [None]:
smodel = StackUnstack(detr)
imgs_list = [torch.rand(1,1,128,128) for _ in range(5)]

In [None]:
split = detr_split(smodel, True)

In [None]:
smodel(imgs_list);

backbone: torch.Size([1, 5, 512, 4, 4])
h: torch.Size([1, 5, 256, 4, 4])
pos: torch.Size([80, 1, 256])
tf_input: torch.Size([80, 1, 256])
tf_out: torch.Size([256, 1, 100])
lin: torch.Size([256, 1, 80])
before dec: torch.Size([1, 5, 256, 4, 4])


## Another Transformer
> https://github.com/maxjcohen/transformer

In [None]:
#export
from tst.transformer import Transformer

In [None]:
# Model parameters for the stand-alone `tst` Transformer instantiated in the next cell
d_model = 64 # Latent dim
q = 8 # Query size
v = 8 # Value size
h = 8 # Number of heads
N = 4 # Number of encoder and decoder layers to stack
attention_size = 12 # Attention window size
dropout = 0.2 # Dropout rate
pe = None # Positional encoding (None here; presumably disables it -- confirm against the tst docs)
chunk_mode = None

d_input = 256 # From dataset
d_output = 256 # From dataset

In [None]:
tf = Transformer(d_input, d_model, d_output, q, v, h, N, attention_size=attention_size, dropout=dropout, chunk_mode=chunk_mode, pe=pe)

In [None]:
tf(torch.rand(8,10,256)).shape

torch.Size([8, 10, 256])

In [None]:
#export
class TransformerTS(Module):
    """
    Sequence model: per-frame CNN encoder -> `tst` Transformer over all (time, row, col) positions -> per-frame upsampling decoder.

    Args:
        arch: backbone architecture passed to `Encoder`.
        n_in: number of input image channels.
        n_out: number of output image channels.
        hidden_dim: width of the transformer input/model/output and number of
            channels fed to the decoder.
        debug: when True, print intermediate tensor shapes in `forward`.
    """
    def __init__(self,  arch=resnet34, n_in=3, n_out=1, hidden_dim=256, debug=False):
        self.debug = debug

        # the image encoder, applied to every frame of the sequence
        self.backbone = TimeDistributed(Encoder(arch, n_in=n_in, pretrained=True))

        # conversion layer; 512 is the channel count the default resnet34 body produces
        self.conv = TimeDistributed(nn.Conv2d(512, hidden_dim, 1))

        # create the `tst` Transformer (input, model and output widths all equal hidden_dim)
        q = 8 # Query size
        v = 8 # Value size
        h = 8 # Number of heads
        n = 4 # Number of encoder and decoder layers to stack
        attention_size = 12 # Attention window size
        dropout = 0.2 # Dropout rate
        pe = None # Positional encoding
        chunk_mode = None
        self.transformer = Transformer(hidden_dim, hidden_dim, hidden_dim, q, v, h, n,
                                       attention_size=attention_size, dropout=dropout,
                                       chunk_mode=chunk_mode, pe=pe)

        # decoder: five upsampling blocks followed by a 3x3 conv to n_out channels
        self.decoder = TimeDistributed(nn.Sequential(
                           UpsampleBlock(hidden_dim, 128, residual=False),
                           UpsampleBlock(128, 128, residual=False),
                           UpsampleBlock(128, 64, residual=False),
                           UpsampleBlock(64, 32, residual=False),
                           UpsampleBlock(32, 16, residual=False),
                           nn.Conv2d(16, n_out, 3,1,1))
                                      )
    def forward(self, inputs):
        "Encode each frame of `inputs`, run the transformer over the flattened positions and decode back to frames"
        # propagate inputs through ResNet up to avg-pool layer
        x = self.backbone(inputs)
        if self.debug: print(f'backbone: {x.shape}')

        # convert from the latent dim to hidden_dim feature planes for the transformer
        h = self.conv(x)
        if self.debug: print(f'h: {h.shape}')
        bs,T,_, H,W = h.shape

        # (bs, T, C, H, W) -> (bs, C, T*H*W) -> (bs, T*H*W, C): one token per (t, row, col)
        tf_input = h.permute(0,2,1,3,4).flatten(2).permute(0,2,1)
        if self.debug: print(f'tf_input: {tf_input.shape}')
        h = self.transformer(tf_input)
        if self.debug: print(f'tf_out: {h.shape}')

        # Undo the flattening: tokens are ordered (T, H, W) with channels last, so
        # unflatten to (bs, T, H, W, C) and move channels in front of the spatial
        # axes. A plain view to (bs, T, C, H, W) would interleave channel and
        # spatial values.
        h = h.reshape(bs, T, H, W, -1).permute(0, 1, 4, 2, 3).contiguous()
        if self.debug: print(f'before dec: {h.shape}')
        return self.decoder(h)

In [None]:
#export
def tf_split(m, stacked=False):
    "Split `m` into two parameter groups: [backbone, conv + transformer + decoder]; with `stacked`, look inside `m.module`"
    # When stacked, the actual model lives under `.module` of the wrapper
    net = m.module if stacked else m
    backbone_group = params(net.backbone)
    rest_group = params(net.conv) + params(net.transformer) + params(net.decoder)
    return [backbone_group, rest_group]
        

In [None]:
tfts = TransformerTS(debug=True)

In [None]:
tfts(torch.rand(2,5,3,128,128)).shape

backbone: torch.Size([2, 5, 512, 4, 4])
h: torch.Size([2, 5, 256, 4, 4])
tf_input: torch.Size([2, 80, 256])
tf_out: torch.Size([2, 80, 256])
before dec: torch.Size([2, 5, 256, 4, 4])


torch.Size([2, 5, 1, 128, 128])

# Export -

In [None]:
# hide
from nbdev.export import *
notebook2script()

Converted 00_data.ipynb.
Converted 01_models.conv_rnn.ipynb.
Converted 02_models.transformer.ipynb.
Converted index.ipynb.
