In [None]:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Conv2D

class TfDETR(Model):
    """Minimal DETR-style object detector built on Keras.

    A ResNet-50 backbone produces a feature map, a 1x1 convolution projects
    it to ``hidden_dim`` channels, learned row/column embeddings are combined
    into a 2-D positional encoding, and a transformer turns 100 learned
    object queries into per-query class logits and box coordinates.

    The transformer itself is not constructed here. Pass one through the
    ``transformer`` argument or assign ``model.transformer`` before the first
    forward pass. It is invoked as ``transformer(src, tgt)`` with
    sequence-first tensors:

    * ``src``: ``(H*W, batch, hidden_dim)``
    * ``tgt``: ``(100, 1, hidden_dim)`` (singleton batch axis)

    and its output is treated as rank 3 with the query axis first; axes 0
    and 1 are swapped before the prediction heads are applied.

    Args:
        num_classes: Number of object classes. The classification head emits
            ``num_classes + 1`` logits per query.
        hidden_dim: Channel width fed to the transformer. Must be even,
            because the column and row embeddings each contribute
            ``hidden_dim // 2`` channels.
        nheads: Not used by this class (no transformer is built here); kept
            so existing call sites keep working.
        num_encoder_layers: Not used by this class; see ``nheads``.
        num_decoder_layers: Not used by this class; see ``nheads``.
        transformer: Optional callable ``transformer(src, tgt)``. Defaults
            to ``None``, in which case it must be assigned later.

    Raises:
        ValueError: If ``hidden_dim`` is odd.
    """

    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 transformer=None):
        super().__init__()
        if hidden_dim % 2:
            raise ValueError(
                f"hidden_dim must be even so the row and column embeddings "
                f"concatenate to hidden_dim channels, got {hidden_dim}")

        # ResNet50() is built with its default arguments; the backbone is cut
        # at the third layer from the end, which the forward pass treats as
        # the last feature map before the pooling/classification layers.
        resnet = ResNet50()
        self.backbone = Model(resnet.input, resnet.layers[-3].output)

        # 1x1 convolution: backbone channels -> hidden_dim.
        self.conv = Conv2D(hidden_dim, 1)

        # Externally supplied transformer (may be assigned after __init__).
        self.transformer = transformer

        # Prediction heads: class logits (+1 extra class) and 4 box values.
        self.linear_class = Dense(units=num_classes + 1)
        self.linear_bbox = Dense(units=4)

        # Learned object queries and spatial embeddings. The embedding tables
        # hold 50 entries each, so feature maps larger than 50 along either
        # axis are not supported.
        self.query_pos = tf.Variable(tf.random.uniform((100, hidden_dim)))
        self.row_embed = tf.Variable(tf.random.uniform((50, hidden_dim // 2)))
        self.col_embed = tf.Variable(tf.random.uniform((50, hidden_dim // 2)))

    def call(self, x):
        """Run the detector on a batch of images.

        Args:
            x: Input accepted by the ResNet-50 backbone (channels-last).

        Returns:
            dict with ``'pred_logits'`` (output of the classification head)
            and ``'pred_boxes'`` (sigmoid of the box head), both computed on
            the transformer output after swapping its first two axes.

        Raises:
            ValueError: If no transformer has been provided.
        """
        if self.transformer is None:
            raise ValueError(
                "TfDETR.transformer is not set; pass `transformer=` to the "
                "constructor or assign `model.transformer` before calling "
                "the model.")

        # propagate inputs through the truncated ResNet-50 backbone
        x = self.backbone(x)

        # project backbone features to hidden_dim planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        H, W, C = h.shape[-3:]  # channels in last dimension in tf

        pos = tf.concat([
            tf.tile(tf.expand_dims(self.col_embed[:W], 0), (H, 1, 1)),
            tf.tile(tf.expand_dims(self.row_embed[:H], 1), (1, W, 1)),
        ], axis=-1)
        # (H, W, C) -> (H*W, 1, C): sequence-first with a broadcastable
        # batch axis.
        pos = tf.expand_dims(tf.reshape(pos, (-1, pos.shape[-1])), 1)

        # Flatten the spatial grid and move it in front of the batch axis so
        # the features line up with `pos`: (batch, H, W, C) -> (H*W, batch, C).
        # Using -1 for the batch axis keeps this valid when the static batch
        # size is unknown.
        src = tf.reshape(h, (-1, H * W, C))
        src = tf.transpose(src, (1, 0, 2))

        # propagate through the transformer
        h = self.transformer(pos + 0.1 * src,
                             tf.expand_dims(self.query_pos, 1))

        # Swap the first two axes (queries <-> batch). `perm` must list every
        # axis of the rank-3 output.
        h = tf.transpose(h, (1, 0, 2))

        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': tf.sigmoid(self.linear_bbox(h))}