In [None]:
# FYI variables in backbone may need to be unfrozen

In [61]:
import jax
import flax
import flax.linen as nn
import jax.numpy as jnp
from jax import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from jax_resnet import pretrained_resnest, Sequential, slice_variables

In [69]:
key = random.PRNGKey(0)
random.normal(key, (10,)).expand_dims

AttributeError: 'DeviceArray' object has no attribute 'expand_dims'

In [72]:

class DetrObjectDetectionOutput(): # Todo maybe subclass from Model output
    loss: Optional[jnp.ndarray] = None
    loss_dict: Optional[Dict] = None
    logits: jnp.ndarray = None
    pred_boxes: jnp.ndarray = None
    pred_fill: jnp.ndarray = None
    pred_rotation: jnp.ndarray = None
    auxiliary_outputs: Optional[List[Dict]] = None
    last_hidden_state: Optional[jnp.ndarray] = None
    decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
    cross_attentions: Optional[Tuple[jnp.ndarray]] = None
    encoder_last_hidden_state: Optional[jnp.ndarray] = None
    encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
    encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
        
        
class DetrConvModel(nn.Module):
    """
    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
    """

    def __init__(self, conv_encoder, position_embedding):
        super().__init__()
        self.conv_encoder = conv_encoder
        self.position_embedding = position_embedding

    def forward(self, pixel_values, pixel_mask):
        # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
        out = self.conv_encoder(pixel_values, pixel_mask)
        pos = []
        for feature_map, mask in out:
            # position encoding
            pos.append(self.position_embedding(feature_map, mask)) # TODO Check .to(feature_map.dtype)

        return out, pos
    
def _expand_mask(mask: jnp.ndarray, dtype: jnp.dtype, tgt_len: Optional[int] = None):
    raise NotImplemented
    
    
# class DetrSinePositionEmbedding(nn.Module):
#     """
#     This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
#     need paper, generalized to work on images.
#     """
#     raise NotImplemented

#     def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
#         super().__init__()
#         self.embedding_dim = embedding_dim
#         self.temperature = temperature
#         self.normalize = normalize
#         if scale is not None and normalize is False:
#             raise ValueError("normalize should be True if scale is passed")
#         if scale is None:
#             scale = 2 * math.pi
#         self.scale = scale

#     def forward(self, pixel_values, pixel_mask):
#         assert pixel_mask is not None, "No pixel mask provided"
#         y_embed = jnp.cumsum(pixel_mask, 1, dtype=jnp.float32)
#         x_embed = jnp.cumsum(pixel_mask, 2, dtype=jnp.float32)
#         if self.normalize:
#             y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
#             x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

#         dim_t = jnp.arange(stop=self.embedding_dim, dtype=jnp.float32)
#         dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim)

#         pos_x = x_embed[:, :, :, None] / dim_t
#         pos_y = y_embed[:, :, :, None] / dim_t
#         pos_x = jnp.stack((jnp.sin(pos_x[:, :, :, 0::2]), jnp.cos(pos_x[:, :, :, 1::2])), dim=4).flatten(3)
#         pos_y = jnp.stack((jnp.sin(pos_y[:, :, :, 0::2]), jnp.cos(pos_y[:, :, :, 1::2])), dim=4).flatten(3)
#         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
#         return pos


class DetrLearnedPositionEmbedding(nn.Module):
    """
    This module learns 2D positional embeddings up to a fixed maximum size.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        self.row_embeddings = nn.Embed(50, embedding_dim)
        self.column_embeddings = nn.Embed(50, embedding_dim)

    def forward(self, pixel_values, pixel_mask=None):
        h, w = pixel_values.shape[-2:]
        i = jnp.arange(w, device=pixel_values.device)
        j = jnp.arange(h, device=pixel_values.device)
        x_emb = self.column_embeddings(i)
        y_emb = self.row_embeddings(j)
        pos = jnp.cat([jnp.repeat(jnp.expand_dims(x_emb, 0), (h, 1, 1)), jnp.repeat(jnp.expand_dims(y_emb, 1), (1, w, 1))], dim=-1)
        pos = jnp.permute(pos, (2, 0, 1))
        pos = jnp.expand_dims(pos, 0)
        pos = jnp.repeat(pos (pixel_values.shape[0], 1, 1, 1))
        return pos
    
    
def build_position_encoding(config):
    '''
    Builds the position encoding - half the size of d_model as we concat x_emb, y_emb
    '''
    n_embed = config.d_model // 2
    if config.position_embedding_type == "sine":
        # TODO find a better way of exposing other arguments
        position_embedding = DetrSinePositionEmbedding(n_embed, normalize=True)
    elif config.position_embedding_type == "learned":
        position_embedding = DetrLearnedPositionEmbedding(n_embed)
    else:
        raise ValueError(f"Not supported {config.position_embedding_type}")

    return position_embedding



class DetrAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
        self.scaling = self.head_dim ** -0.5

        self.k_proj = nn.Dense(embed_dim, bias=bias)
        self.v_proj = nn.Dense(embed_dim, bias=bias)
        self.q_proj = nn.Dense(embed_dim, bias=bias)
        self.out_proj = nn.Dense(embed_dim, bias=bias)

    def _shape(self, tensor: jnp.ndarray, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def with_pos_embed(self, tensor: jnp.ndarray, position_embeddings: Optional[Tensor]):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_embeddings: Optional[jnp.ndarray] = None,
        key_value_states: Optional[jnp.ndarray] = None,
        key_value_position_embeddings: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
    ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], Optional[Tuple[jnp.ndarray]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, embed_dim = hidden_states.size()

        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states_original = hidden_states
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        # add key-value position embeddings to the key value states
        if key_value_position_embeddings is not None:
            key_value_states_original = key_value_states
            key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states_original), -1, bsz)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

In [30]:
ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()
idx = 18 # gives B, 7,7,2048
backbone, backbone_variables = Sequential(model.layers[0:idx]), slice_variables(variables, end=idx) 
output = backbone.apply(backbone_variables, jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
                  mutable=False)  # Ensure `batch_stats` aren't updated.


(32, 7, 7, 2048)

In [37]:
model

Sequential(
    # attributes
    layers = [ResNetDStem(
        # attributes
        conv_block_cls = functools.partial(<class 'jax_resnet.common.ConvBlock'>, conv_cls=<class 'flax.linen.linear.Conv'>, norm_cls=functools.partial(<class 'flax.linen.normalization.BatchNorm'>, momentum=0.9))
        stem_width = 32
        adaptive_first_width = False
    ), functools.partial(<function max_pool at 0x7fa03024caf0>, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1))), ResNeStBottleneckBlock(
        # attributes
        skip_cls = ResNeStSkipConnection
        avg_pool_first = False
        radix = 2
        splat_cls = functools.partial(<class 'jax_resnet.splat.SplAtConv2d'>, match_reference=True)
    ), ResNeStBottleneckBlock(
        # attributes
        skip_cls = ResNeStSkipConnection
        avg_pool_first = False
        radix = 2
        splat_cls = functools.partial(<class 'jax_resnet.splat.SplAtConv2d'>, match_reference=True)
    ), ResNeStBottleneckBlock(
        # att

# Optimiser
Ensure we freeze batch layers (don't put in opt) and batch stats

In [None]:
https://flax.readthedocs.io/en/latest/flax.optim.html#flax.optim.OptimizerDef.create
# Freeze both     
https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html