# Feature Pyramid Networks (FPN)

## Functional API

In [1]:
import tensorflow as tf


def conv_norm(
    x: tf.Tensor,
    filters: int,
    ksize: int,
    *,
    downsample: bool = False,
    act: bool = True,
) -> tf.Tensor:
    """Conv2D + BatchNorm with optional ReLU activation.

    Args:
        x (tf.Tensor): Input tensor.
        filters (int): Number of output channels.
        ksize (int): Kernel size. A 1x1 kernel (e.g. a projection shortcut)
            uses "valid" padding; every other kernel size uses "same"
            padding.
        downsample (bool): Down-sampling flag, i.e. use stride 2 if True and
            stride 1 otherwise.
        act (bool): If True, apply ReLU after the BatchNormalization.

    Returns:
        tf.Tensor: Output tensor after Conv2D, BatchNormalization and the
            optional ReLU.
    """
    strides = 2 if downsample else 1
    # A 1x1 kernel needs no border padding: "valid" already yields
    # ceil(size / stride) outputs, the same as "same" would.
    padding = "valid" if ksize == 1 else "same"

    x = tf.keras.layers.Conv2D(
        filters=filters,
        kernel_size=ksize,
        strides=strides,
        # The following BatchNormalization adds its own learned offset,
        # so a convolution bias would be redundant.
        use_bias=False,
        padding=padding,
    )(x)

    x = tf.keras.layers.BatchNormalization()(x)

    if act:
        x = tf.nn.relu(x)
    return x


def bottleneck(
    x: tf.Tensor,
    filters: int,
    *,
    ds: bool = False,
    lead: bool = False,
) -> tf.Tensor:
    """A bottleneck residual block using the functional API.

    The residual branch consists of 3 Conv2D layers with kernel sizes 1, 3, 1:
    the first 1x1 conv maps the input to `filters` channels, the 3x3 conv
    runs at that narrow width (and down-samples when `ds` is set), and the
    last 1x1 conv expands the result to `4 * filters` channels. Running the
    3x3 conv on the narrow representation is what saves computation
    compared to the basic residual block.

    The shortcut is the identity, except when the block is the first of a
    stack (`lead`) or down-samples (`ds`). In those cases a 1x1 projection
    (Conv2D + BatchNorm) brings the shortcut to the residual branch's shape.

    Args:
        x (tf.Tensor): Input tensor.
        filters (int): Bottleneck width; the block outputs `4 * filters`
            channels.
        ds (bool, optional): Down-sampling flag (stride 2 in the 3x3 conv and
            in the projection shortcut). Defaults to False.
        lead (bool, optional): Whether this is the first block of a stack,
            which forces a projection shortcut. Defaults to False.

    Returns:
        tf.Tensor: ReLU(shortcut + residual), with `4 * filters` channels.
    """
    # expanded output channels
    outs = filters * 4

    # A projection is needed whenever the shortcut could differ in shape from
    # the residual branch: a channel change (lead) or a stride of 2 (ds).
    # An identity shortcut with ds=True would not match spatially.
    if lead or ds:
        identity = conv_norm(x, outs, ksize=1, downsample=ds, act=False)
    else:
        identity = x

    x = conv_norm(x, filters, ksize=1, downsample=False, act=True)
    x = conv_norm(x, filters, ksize=3, downsample=ds, act=True)
    # No activation before the addition; ReLU is applied to the sum.
    x = conv_norm(x, outs, ksize=1, downsample=False, act=False)
    return tf.nn.relu(identity + x)


def res_stack(
    x: tf.Tensor,
    filter: int,
    num: int,
    *,
    ds: bool = False,
) -> tf.Tensor:
    """A stack of bottleneck residual blocks.

    The first block uses a projection shortcut and applies the down-sampling
    (if any); the remaining `num - 1` blocks keep shape and use identity
    shortcuts.

    Args:
        x (tf.Tensor): Input tensor.
        filter (int): Bottleneck width of every block; the stack outputs
            `4 * filter` channels. (The name shadows the builtin but is kept
            for callers passing it by keyword.)
        num (int): Number of bottleneck blocks; must be at least 1.
        ds (bool): Down-sampling flag applied to the first block.

    Returns:
        tf.Tensor: Output tensor of the last bottleneck block.

    Raises:
        ValueError: If `num` is less than 1.
    """
    if num < 1:
        # Previously num <= 0 silently built one block anyway.
        raise ValueError(f"num must be >= 1, got {num}")

    x = bottleneck(x, filter, ds=ds, lead=True)
    for _ in range(num - 1):
        x = bottleneck(x, filter, ds=False, lead=False)

    return x


def topdown_stack(
    c: tf.Tensor,
    p: tf.Tensor,
    *,
    channels: int = 256,
) -> tf.Tensor:
    """A top-down stack of feature maps.

    1. the bottom-up feature map `c` is reduced to `channels` channels by a
       1x1 Conv2D + BatchNorm + ReLU (lateral connection).
    2. the top-down feature map `p` is resized (tf.image.resize default,
       bilinear) to the spatial size of `c`.
    3. the two feature maps are added together.
    4. the combined feature map is smoothed by a 3x3 Conv2D + BatchNorm +
       ReLU.

    Args:
        c (tf.Tensor): Input tensor from the bottom-up pathway.
        p (tf.Tensor): Input tensor from the top-down pathway; it must
            already have `channels` channels for the addition to work.
        channels (int, optional): Number of channels of the output feature
            map. Defaults to 256.

    Returns:
        tf.Tensor: Output tensor after the top-down stack, with the spatial
            size of `c` and `channels` channels.
    """
    # reduce channels
    c = conv_norm(c, channels, ksize=1, downsample=False, act=True)

    height, width = c.shape[1], c.shape[2]
    if height is None or width is None:
        # The spatial size is unknown when the graph is built (e.g. an Input
        # with shape (None, None, C)), so read it from the tensor at run time.
        size = tf.shape(c)[1:3]
    else:
        size = (height, width)
    p = tf.image.resize(p, size) + c

    # smooth
    p = conv_norm(p, channels, ksize=3, downsample=False, act=True)
    return p


def fpn(input_shape) -> tf.keras.Model:
    """Constructs the Feature Pyramid Network (FPN) model.

    The bottom-up pathway is a ResNet-50-style backbone (bottleneck stacks
    of 3, 4, 6, 3 blocks). The top-down pathway merges its stage outputs
    C2..C5 into the 256-channel pyramid levels P2..P5.

    Args:
        input_shape (tuple): Shape of one input sample, excluding the batch
            dimension, e.g. (H, W, C).

    Returns:
        tf.keras.Model: The FPN model. Its outputs are [p2, p3, p4, p5],
            from the finest level (stride 4) to the coarsest (stride 32).
    """

    inputs = tf.keras.layers.Input(shape=input_shape)  # (B, H, W, C)

    # Initial layers
    # shape: (B, H/2, W/2, 64)
    x = conv_norm(inputs, 64, ksize=7, downsample=True, act=True)
    # shape: (B, H/4, W/4, 64)
    x = tf.nn.max_pool2d(x, ksize=3, strides=2, padding="SAME")

    # bottom-up pathway
    c2 = res_stack(x, 64, 3, ds=False)  # (B, H/4, W/4, 256)
    c3 = res_stack(c2, 128, 4, ds=True)  # (B, H/8, W/8, 512)
    c4 = res_stack(c3, 256, 6, ds=True)  # (B, H/16, W/16, 1024)
    c5 = res_stack(c4, 512, 3, ds=True)  # (B, H/32, W/32, 2048)

    # top-down pathway
    # The coarsest level only gets the lateral 1x1 reduction; each finer
    # level merges its bottom-up map with the level above it.
    # shape: (B, H/32, W/32, 256)
    p5 = conv_norm(c5, 256, ksize=1, downsample=False, act=True)
    p4 = topdown_stack(c4, p5)  # (B, H/16, W/16, 256)
    p3 = topdown_stack(c3, p4)  # (B, H/8, W/8, 256)
    p2 = topdown_stack(c2, p3)  # (B, H/4, W/4, 256)

    return tf.keras.Model(inputs=inputs, outputs=[p2, p3, p4, p5])

In [2]:
# Smoke test: build the model for 416x416 RGB inputs and check the shapes of
# the four pyramid levels on a random batch.
in_shape = (416, 416, 3)
batch = 2
x = tf.random.normal(shape=(batch, *in_shape))

model = fpn(in_shape)
model.summary()

# training=False runs BatchNormalization in inference mode.
y2, y3, y4, y5 = model(x, training=False)

# P2..P5 have strides 4, 8, 16, 32 relative to the input and 256 channels
# each (416 -> 104, 52, 26, 13).
for level, stride in zip((y2, y3, y4, y5), (4, 8, 16, 32)):
    expected = (batch, in_shape[0] // stride, in_shape[1] // stride, 256)
    assert level.shape == expected

2023-09-23 01:54:12.229754: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 416, 416, 3)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 208, 208, 64)         9408      ['input_1[0][0]']             
                                                                                                  
 batch_normalization (Batch  (None, 208, 208, 64)         256       ['conv2d[0][0]']              
 Normalization)                                                                                   
                                                                                                  
 tf.nn.relu (TFOpLambda)     (None, 208, 208, 64)         0         ['batch_normalization[0][0