# Custom pooling layers

In [1]:
!pip install -q timm

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 8.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.3/21.3 MB 19.3 MB/s eta 0:00:00

In [21]:
import math
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import timm

# GeM layer
The Generalized Mean (GeM) Pooling and Attention Pooling layers below are adapted from `src/model_zoo/layers.py` in the [kaggle_rsna_abdominal_trauma](https://github.com/TheoViel/kaggle_rsna_abdominal_trauma/blob/cleaning/src/model_zoo/layers.py) repository.

In [11]:
def gem(x, p=3, eps=1e-4):
    """
    Apply Generalized Mean Pooling (GeM) over the spatial dimensions of a tensor.

    Computes ``(mean_{h, w} clamp(x, min=eps) ** p) ** (1 / p)`` for every
    batch item and channel. With ``p = 1`` this is average pooling; as ``p``
    grows the result approaches max pooling.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width)
        p (float or torch.Tensor): Exponent of the generalized mean. Default is 3
        eps (float): Lower bound applied to ``x`` before exponentiation, so that
            fractional powers of zero/negative values are avoided. Default is 1e-4

    Returns:
        torch.Tensor: GeM-pooled tensor of shape (batch_size, channels, 1, 1)
    """
    # GeM is the power mean (mean(x ** p)) ** (1 / p): the p-th power must be
    # taken *before* averaging, otherwise the result is just mean(x) ** (1 / p).
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, (x.size(-2), x.size(-1)))
    return pooled.pow(1.0 / p)


class GeM(nn.Module):
    """
    Module wrapper around the functional ``gem`` pooling.

    Attributes:
        p (float or torch.nn.Parameter): Exponent of the generalized mean. Stored
            as a one-element learnable Parameter when ``p_trainable`` is True,
            otherwise kept as the plain value that was passed in.
        eps (float): Lower clamp forwarded to ``gem`` (it bounds the input from
            below; it is not added to any denominator).
    """

    def __init__(self, p=3, eps=1e-6, p_trainable=False):
        """
        Build the pooling layer.

        Args:
            p (float): Initial exponent of the generalized mean. Defaults to 3
            eps (float, optional): Lower clamp forwarded to ``gem``. Defaults to 1e-6
            p_trainable (bool, optional): Register ``p`` as a learnable Parameter.
                Defaults to False
        """
        super().__init__()
        # Assigning a Parameter registers it with the module, so it appears in
        # .parameters() and is updated by an optimizer.
        self.p = Parameter(torch.ones(1) * p) if p_trainable else p
        self.eps = eps

    def forward(self, x):
        """
        Pool ``x`` by delegating to ``gem`` with this layer's ``p`` and ``eps``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width)
        Returns:
            torch.Tensor: Pooled tensor of shape (batch_size, channels, 1, 1)
        """
        return gem(x, p=self.p, eps=self.eps)


# Smoke test: a random batch of 8 three-channel 224x224 images.
xx = torch.rand(8, 3, 224, 224)
g = GeM()
g(xx).shape  # expected torch.Size([8, 3, 1, 1]): one pooled value per channel
# Functional form (note: gem's default eps is 1e-4, GeM's is 1e-6):
# yy = gem(xx)
# yy.shape

torch.Size([8, 3, 1, 1])

# Attention Pooling

In [20]:
class Attention(nn.Module):
    """
    Additive (tanh) attention pooling over the sequence dimension.

    Each timestep ``x_t`` is scored as ``v^T tanh(W x_t + b)``. The scores are
    softmax-normalized over the sequence and used as weights for a weighted
    sum of the timesteps.

    Attributes:
        hidden_dim (int): Feature size of each timestep (last dim of the input)
        attention_dim (int): Size of the intermediate attention projection
        proj_w (nn.Linear): Projection ``W x + b`` (hidden_dim -> attention_dim)
        proj_v (nn.Linear): Scoring vector ``v^T`` (attention_dim -> 1, no bias)
    """

    def __init__(self, hidden_dim, attention_dim=None):
        """
        Build the attention layer.

        Args:
            hidden_dim (int): Feature size of each timestep of the input sequence
            attention_dim (int, optional): Size of the attention projection.
                Defaults to None, in which case it is set to ``hidden_dim``
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.attention_dim = hidden_dim if attention_dim is None else attention_dim
        # W * x + b
        self.proj_w = nn.Linear(self.hidden_dim, self.attention_dim, bias=True)
        # v.T
        self.proj_v = nn.Linear(self.attention_dim, 1, bias=False)

    def forward(self, x):
        """
        Pool a batch of sequences into one vector per sequence.

        Args:
            x (torch.Tensor): Input of shape (batch_size, seq_len, hidden_dim)
        Returns:
            torch.Tensor: Attention-weighted sum over the sequence, of shape
            (batch_size, hidden_dim)
        Raises:
            ValueError: If ``x`` is not 3-D.
        """
        # The softmax below normalizes over dim 1, which is only the sequence
        # axis for a 3-D input, so reject other ranks explicitly.
        if x.dim() != 3:
            raise ValueError(
                "Attention expects a 3-D input (batch_size, seq_len, hidden_dim), "
                f"got shape {tuple(x.shape)}"
            )
        scores = self.proj_v(torch.tanh(self.proj_w(x)))  # (batch, seq_len, 1)
        weights = torch.softmax(scores, dim=1)  # normalize over the sequence
        return (x * weights).sum(dim=1)

# Smoke test: 8 sequences of 384 steps, each step a 224-dim feature vector.
x = torch.randn(8, 384, 224)
attn = Attention(hidden_dim=224)  # hidden_dim must match the last input dim
attn(x).shape  # expected torch.Size([8, 224]): the sequence dim is pooled away

# hidden dim == input_dim

torch.Size([8, 224])

# Sample model

In [126]:
class TimmModel(nn.Module):
    """
    timm backbone + global pooling + sequence-pooling head.

    Every image is encoded independently into one feature vector. The vectors
    belonging to one sample are then treated as a sequence and reduced to a
    single vector by the head selected with ``pooling``.

    Args:
        backbone (str): Model name passed to ``timm.create_model``
        pretrained (bool): Forwarded to ``timm.create_model``
        use_gem (bool): Pool 4-D feature maps with GeM instead of average pooling
        pooling (str): Sequence head, one of "mean", "max", "lstm", "lstm_att"

    Raises:
        ValueError: If ``pooling`` is not one of the supported values.
    """

    _POOLINGS = ("mean", "max", "lstm", "lstm_att")

    def __init__(self, backbone, pretrained=False, use_gem=False, pooling="mean"):
        super(TimmModel, self).__init__()
        # Validate before building the backbone; an unknown value used to be
        # accepted silently and forward_head then returned the unpooled tensor.
        if pooling not in self._POOLINGS:
            raise ValueError(
                f"Unknown pooling {pooling!r}; expected one of {self._POOLINGS}"
            )
        # num_classes=0 requests the backbone without a classification head.
        self.encoder = timm.create_model(backbone, pretrained=pretrained, num_classes=0)
        self.nb_fts = self.encoder.num_features
        self.use_gem = use_gem
        self.pooling = pooling

        if self.use_gem:
            self.global_pool = GeM(p_trainable=False)
        else:
            self.global_pool = nn.AdaptiveAvgPool2d(1)

        if self.pooling == "lstm":
            # Bidirectional output is 2 * (nb_fts // 4); forward_head concatenates
            # its mean and max over time -> 4 * (nb_fts // 4) features.
            self.lstm = nn.LSTM(
                self.nb_fts, self.nb_fts // 4, batch_first=True, bidirectional=True
            )
        elif self.pooling == "lstm_att":
            hidden = self.nb_fts // 2
            self.lstm = nn.LSTM(
                self.nb_fts, hidden, batch_first=True, bidirectional=True
            )
            # Attention consumes the bidirectional LSTM output (2 * hidden
            # features), which is one less than nb_fts when nb_fts is odd.
            self.att = Attention(2 * hidden)

    def extract_features(self, x):
        """
        Extract one feature vector per input image.

        Args:
            x (torch.Tensor): Input images of shape [batch_size x channels x H x W]

        Returns:
            torch.Tensor: Extracted features of shape [batch_size x num_features]

        Raises:
            ValueError: If the encoder output is not 2-, 3- or 4-D.
        """
        fts = self.encoder.forward_features(x)

        if fts.dim() == 4:
            # Assumed channels-first feature map [B x C x h x w]: pool the spatial
            # dims with the configured global pool (GeM or average) -> [B x C].
            fts = self.global_pool(fts).flatten(1)
        elif fts.dim() == 3:
            # Assumed token sequence [B x N x C] (e.g. transformer backbones):
            # average over tokens so that the channel dim is preserved.
            fts = fts.mean(1)
        elif fts.dim() != 2:
            raise ValueError(
                f"Unexpected encoder feature shape {tuple(fts.shape)}"
            )
        return fts

    def forward_head(self, x):
        """
        Reduce a feature sequence to one vector per sample.

        Args:
            x (torch.Tensor): Features of shape [batch_size x seq_len x nb_fts]

        Returns:
            torch.Tensor: [batch_size x D], with D = nb_fts for "mean"/"max",
            4 * (nb_fts // 4) for "lstm" and 2 * (nb_fts // 2) for "lstm_att".

        Raises:
            ValueError: If ``self.pooling`` is not a supported value.
        """
        if self.pooling == "mean":
            return x.mean(1)
        if self.pooling == "max":
            return x.amax(1)
        if self.pooling == "lstm":
            x, _ = self.lstm(x)
            # Concatenate mean- and max-pooled LSTM outputs over time.
            return torch.cat([x.mean(1), x.amax(1)], -1)
        if self.pooling == "lstm_att":
            x, _ = self.lstm(x)
            return self.att(x)
        raise ValueError(f"Unknown pooling {self.pooling!r}")

    def forward(self, x):
        """
        Encode images and pool them into one vector per sample.

        Args:
            x (torch.Tensor): Either a batch of images [batch_size x C x H x W]
                (each image is a length-1 sequence) or a batch of image sequences
                [batch_size x seq_len x C x H x W].

        Returns:
            torch.Tensor: Pooled features of shape [batch_size x D] (see forward_head)
        """
        if x.dim() == 5:
            # Fold the sequence into the batch so the encoder sees plain images.
            bs, seq_len = x.shape[:2]
            x = x.reshape(bs * seq_len, *x.shape[2:])
        else:
            bs = x.size(0)
        fts = self.extract_features(x)
        fts = fts.reshape(bs, -1, self.nb_fts)
        return self.forward_head(fts)

# Smoke test: ResNet-18 backbone, GeM pooling, BiLSTM + attention head,
# run on the random image batch `xx` created in the GeM cell.
model = TimmModel(backbone='resnet18', use_gem=True, pooling="lstm_att")
model(xx).shape  # one vector per image; width follows encoder.num_features

torch.Size([8, 512])